Add kwargs to Transform.apply for num_epochs
Change: 128989804
This commit is contained in:
parent
52c0418614
commit
48e869f0e3
@ -117,10 +117,11 @@ class DataFrame(object):
|
|||||||
value = [value]
|
value = [value]
|
||||||
self.assign(**dict(zip(key, value)))
|
self.assign(**dict(zip(key, value)))
|
||||||
|
|
||||||
def build(self):
|
def build(self, **kwargs):
|
||||||
# We do not allow passing a cache here, because that would encourage
|
# We do not allow passing a cache here, because that would encourage
|
||||||
# working around the rule that DataFrames cannot be expected to be
|
# working around the rule that DataFrames cannot be expected to be
|
||||||
# synced with each other (e.g., they shuffle independently).
|
# synced with each other (e.g., they shuffle independently).
|
||||||
cache = {}
|
cache = {}
|
||||||
tensors = {name: c.build(cache) for name, c in self._columns.items()}
|
tensors = {name: c.build(cache, **kwargs)
|
||||||
|
for name, c in self._columns.items()}
|
||||||
return tensors
|
return tensors
|
||||||
|
@ -91,7 +91,8 @@ def _build_alternate_universe(
|
|||||||
def to_feature_columns_and_input_fn(dataframe,
|
def to_feature_columns_and_input_fn(dataframe,
|
||||||
base_input_keys_with_defaults,
|
base_input_keys_with_defaults,
|
||||||
feature_keys,
|
feature_keys,
|
||||||
target_keys=None):
|
target_keys=None,
|
||||||
|
**kwargs):
|
||||||
"""Build a list of FeatureColumns and an input_fn for use with Estimator.
|
"""Build a list of FeatureColumns and an input_fn for use with Estimator.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -103,6 +104,7 @@ def to_feature_columns_and_input_fn(dataframe,
|
|||||||
These may include base features and/or derived features.
|
These may include base features and/or derived features.
|
||||||
target_keys: the names of columns to be used as targets. None is
|
target_keys: the names of columns to be used as targets. None is
|
||||||
acceptable for unsupervised learning.
|
acceptable for unsupervised learning.
|
||||||
|
**kwargs: Additional keyword arguments, forwarded to `DataFrame.build()`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of two elements:
|
A tuple of two elements:
|
||||||
@ -155,10 +157,11 @@ def to_feature_columns_and_input_fn(dataframe,
|
|||||||
|
|
||||||
# Build an input_fn suitable for use with Estimator.
|
# Build an input_fn suitable for use with Estimator.
|
||||||
def input_fn():
|
def input_fn():
|
||||||
|
"""An input_fn() for feeding the given set of DataFrameColumns."""
|
||||||
# It's important to build all the tensors together in one DataFrame.
|
# It's important to build all the tensors together in one DataFrame.
|
||||||
# If we did df.select() for both key sets and then build those, the two
|
# If we did df.select() for both key sets and then build those, the two
|
||||||
# resulting DataFrames would be shuffled independently.
|
# resulting DataFrames would be shuffled independently.
|
||||||
tensors = limited_dataframe.build()
|
tensors = limited_dataframe.build(**kwargs)
|
||||||
|
|
||||||
base_input_features = {key: tensors[key] for key in base_input_keys}
|
base_input_features = {key: tensors[key] for key in base_input_keys}
|
||||||
targets = {key: tensors[key] for key in target_keys}
|
targets = {key: tensors[key] for key in target_keys}
|
||||||
|
@ -98,7 +98,7 @@ class Series(object):
|
|||||||
return transform_cls
|
return transform_cls
|
||||||
return register
|
return register
|
||||||
|
|
||||||
def build(self, cache):
|
def build(self, cache, **kwargs):
|
||||||
"""Returns a Tensor."""
|
"""Returns a Tensor."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@ -122,7 +122,7 @@ class PredefinedSeries(Series):
|
|||||||
def required_base_features(self):
|
def required_base_features(self):
|
||||||
return {self.name: self.feature_spec}
|
return {self.name: self.feature_spec}
|
||||||
|
|
||||||
def build(self, cache):
|
def build(self, cache, **kwargs):
|
||||||
try:
|
try:
|
||||||
return cache[self.name]
|
return cache[self.name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -171,10 +171,11 @@ class TransformedSeries(Series):
|
|||||||
result.update(s.required_base_features)
|
result.update(s.required_base_features)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def build(self, cache=None):
|
def build(self, cache=None, **kwargs):
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = {}
|
cache = {}
|
||||||
all_outputs = self._transform.build_transitive(self._input_series, cache)
|
all_outputs = self._transform.build_transitive(
|
||||||
|
self._input_series, cache, **kwargs)
|
||||||
return getattr(all_outputs, self._output_name)
|
return getattr(all_outputs, self._output_name)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
@ -83,7 +83,8 @@ class TensorFlowDataFrame(df.DataFrame):
|
|||||||
graph=None,
|
graph=None,
|
||||||
session=None,
|
session=None,
|
||||||
start_queues=True,
|
start_queues=True,
|
||||||
initialize_variables=True):
|
initialize_variables=True,
|
||||||
|
**kwargs):
|
||||||
"""Builds and runs the columns of the `DataFrame` and yields batches.
|
"""Builds and runs the columns of the `DataFrame` and yields batches.
|
||||||
|
|
||||||
This is a generator that yields a dictionary mapping column names to
|
This is a generator that yields a dictionary mapping column names to
|
||||||
@ -97,6 +98,7 @@ class TensorFlowDataFrame(df.DataFrame):
|
|||||||
start_queues: if true, queues will be started before running and halted
|
start_queues: if true, queues will be started before running and halted
|
||||||
after producing `n` batches.
|
after producing `n` batches.
|
||||||
initialize_variables: if true, variables will be initialized.
|
initialize_variables: if true, variables will be initialized.
|
||||||
|
**kwargs: Additional keyword arguments (e.g. `num_epochs`), forwarded to `build()`.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
A dictionary, mapping column names to the values resulting from running
|
A dictionary, mapping column names to the values resulting from running
|
||||||
@ -107,7 +109,7 @@ class TensorFlowDataFrame(df.DataFrame):
|
|||||||
with graph.as_default():
|
with graph.as_default():
|
||||||
if session is None:
|
if session is None:
|
||||||
session = sess.Session()
|
session = sess.Session()
|
||||||
self_built = self.build()
|
self_built = self.build(**kwargs)
|
||||||
keys = list(self_built.keys())
|
keys = list(self_built.keys())
|
||||||
cols = list(self_built.values())
|
cols = list(self_built.values())
|
||||||
if initialize_variables:
|
if initialize_variables:
|
||||||
@ -208,7 +210,7 @@ class TensorFlowDataFrame(df.DataFrame):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_csv_base(cls, filepatterns, get_default_values, has_header,
|
def _from_csv_base(cls, filepatterns, get_default_values, has_header,
|
||||||
column_names, num_epochs, num_threads, enqueue_size,
|
column_names, num_threads, enqueue_size,
|
||||||
batch_size, queue_capacity, min_after_dequeue, shuffle,
|
batch_size, queue_capacity, min_after_dequeue, shuffle,
|
||||||
seed):
|
seed):
|
||||||
"""Create a `DataFrame` from CSV files.
|
"""Create a `DataFrame` from CSV files.
|
||||||
@ -223,9 +225,6 @@ class TensorFlowDataFrame(df.DataFrame):
|
|||||||
each column, given the column names.
|
each column, given the column names.
|
||||||
has_header: whether or not the CSV files have headers.
|
has_header: whether or not the CSV files have headers.
|
||||||
column_names: a list of names for the columns in the CSV files.
|
column_names: a list of names for the columns in the CSV files.
|
||||||
num_epochs: the number of times that the reader should loop through all
|
|
||||||
the file names. If set to `None`, then the reader will continue
|
|
||||||
indefinitely.
|
|
||||||
num_threads: the number of readers that will work in parallel.
|
num_threads: the number of readers that will work in parallel.
|
||||||
enqueue_size: block size for each read operation.
|
enqueue_size: block size for each read operation.
|
||||||
batch_size: desired batch size.
|
batch_size: desired batch size.
|
||||||
@ -265,7 +264,6 @@ class TensorFlowDataFrame(df.DataFrame):
|
|||||||
reader_kwargs=reader_kwargs,
|
reader_kwargs=reader_kwargs,
|
||||||
enqueue_size=enqueue_size,
|
enqueue_size=enqueue_size,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_epochs=num_epochs,
|
|
||||||
queue_capacity=queue_capacity,
|
queue_capacity=queue_capacity,
|
||||||
shuffle=shuffle,
|
shuffle=shuffle,
|
||||||
min_after_dequeue=min_after_dequeue,
|
min_after_dequeue=min_after_dequeue,
|
||||||
@ -287,7 +285,6 @@ class TensorFlowDataFrame(df.DataFrame):
|
|||||||
default_values,
|
default_values,
|
||||||
has_header=True,
|
has_header=True,
|
||||||
column_names=None,
|
column_names=None,
|
||||||
num_epochs=None,
|
|
||||||
num_threads=1,
|
num_threads=1,
|
||||||
enqueue_size=None,
|
enqueue_size=None,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
@ -306,9 +303,6 @@ class TensorFlowDataFrame(df.DataFrame):
|
|||||||
default_values: a list of default values for each column.
|
default_values: a list of default values for each column.
|
||||||
has_header: whether or not the CSV files have headers.
|
has_header: whether or not the CSV files have headers.
|
||||||
column_names: a list of names for the columns in the CSV files.
|
column_names: a list of names for the columns in the CSV files.
|
||||||
num_epochs: the number of times that the reader should loop through all
|
|
||||||
the file names. If set to `None`, then the reader will continue
|
|
||||||
indefinitely.
|
|
||||||
num_threads: the number of readers that will work in parallel.
|
num_threads: the number of readers that will work in parallel.
|
||||||
enqueue_size: block size for each read operation.
|
enqueue_size: block size for each read operation.
|
||||||
batch_size: desired batch size.
|
batch_size: desired batch size.
|
||||||
@ -332,7 +326,7 @@ class TensorFlowDataFrame(df.DataFrame):
|
|||||||
return default_values
|
return default_values
|
||||||
|
|
||||||
return cls._from_csv_base(filepatterns, get_default_values, has_header,
|
return cls._from_csv_base(filepatterns, get_default_values, has_header,
|
||||||
column_names, num_epochs, num_threads,
|
column_names, num_threads,
|
||||||
enqueue_size, batch_size, queue_capacity,
|
enqueue_size, batch_size, queue_capacity,
|
||||||
min_after_dequeue, shuffle, seed)
|
min_after_dequeue, shuffle, seed)
|
||||||
|
|
||||||
@ -342,7 +336,6 @@ class TensorFlowDataFrame(df.DataFrame):
|
|||||||
feature_spec,
|
feature_spec,
|
||||||
has_header=True,
|
has_header=True,
|
||||||
column_names=None,
|
column_names=None,
|
||||||
num_epochs=None,
|
|
||||||
num_threads=1,
|
num_threads=1,
|
||||||
enqueue_size=None,
|
enqueue_size=None,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
@ -362,9 +355,6 @@ class TensorFlowDataFrame(df.DataFrame):
|
|||||||
`VarLenFeature`.
|
`VarLenFeature`.
|
||||||
has_header: whether or not the CSV files have headers.
|
has_header: whether or not the CSV files have headers.
|
||||||
column_names: a list of names for the columns in the CSV files.
|
column_names: a list of names for the columns in the CSV files.
|
||||||
num_epochs: the number of times that the reader should loop through all
|
|
||||||
the file names. If set to `None`, then the reader will continue
|
|
||||||
indefinitely.
|
|
||||||
num_threads: the number of readers that will work in parallel.
|
num_threads: the number of readers that will work in parallel.
|
||||||
enqueue_size: block size for each read operation.
|
enqueue_size: block size for each read operation.
|
||||||
batch_size: desired batch size.
|
batch_size: desired batch size.
|
||||||
@ -387,7 +377,7 @@ class TensorFlowDataFrame(df.DataFrame):
|
|||||||
return [_get_default_value(feature_spec[name]) for name in column_names]
|
return [_get_default_value(feature_spec[name]) for name in column_names]
|
||||||
|
|
||||||
dataframe = cls._from_csv_base(filepatterns, get_default_values, has_header,
|
dataframe = cls._from_csv_base(filepatterns, get_default_values, has_header,
|
||||||
column_names, num_epochs, num_threads,
|
column_names, num_threads,
|
||||||
enqueue_size, batch_size, queue_capacity,
|
enqueue_size, batch_size, queue_capacity,
|
||||||
min_after_dequeue, shuffle, seed)
|
min_after_dequeue, shuffle, seed)
|
||||||
|
|
||||||
@ -405,7 +395,6 @@ class TensorFlowDataFrame(df.DataFrame):
|
|||||||
filepatterns,
|
filepatterns,
|
||||||
features,
|
features,
|
||||||
reader_cls=io_ops.TFRecordReader,
|
reader_cls=io_ops.TFRecordReader,
|
||||||
num_epochs=None,
|
|
||||||
num_threads=1,
|
num_threads=1,
|
||||||
enqueue_size=None,
|
enqueue_size=None,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
@ -421,9 +410,6 @@ class TensorFlowDataFrame(df.DataFrame):
|
|||||||
`FixedLenFeature`.
|
`FixedLenFeature`.
|
||||||
reader_cls: a subclass of `tensorflow.ReaderBase` that will be used to
|
reader_cls: a subclass of `tensorflow.ReaderBase` that will be used to
|
||||||
read the `Example`s.
|
read the `Example`s.
|
||||||
num_epochs: the number of times that the reader should loop through all
|
|
||||||
the file names. If set to `None`, then the reader will continue
|
|
||||||
indefinitely.
|
|
||||||
num_threads: the number of readers that will work in parallel.
|
num_threads: the number of readers that will work in parallel.
|
||||||
enqueue_size: block size for each read operation.
|
enqueue_size: block size for each read operation.
|
||||||
batch_size: desired batch size.
|
batch_size: desired batch size.
|
||||||
@ -454,7 +440,6 @@ class TensorFlowDataFrame(df.DataFrame):
|
|||||||
filenames,
|
filenames,
|
||||||
enqueue_size=enqueue_size,
|
enqueue_size=enqueue_size,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_epochs=num_epochs,
|
|
||||||
queue_capacity=queue_capacity,
|
queue_capacity=queue_capacity,
|
||||||
shuffle=shuffle,
|
shuffle=shuffle,
|
||||||
min_after_dequeue=min_after_dequeue,
|
min_after_dequeue=min_after_dequeue,
|
||||||
|
@ -223,13 +223,14 @@ class Transform(object):
|
|||||||
# pylint: disable=not-callable
|
# pylint: disable=not-callable
|
||||||
return self.return_type(*output_series)
|
return self.return_type(*output_series)
|
||||||
|
|
||||||
def build_transitive(self, input_series, cache=None):
|
def build_transitive(self, input_series, cache=None, **kwargs):
|
||||||
"""Apply this `Transform` to the provided `Series`, producing 'Tensor's.
|
"""Apply this `Transform` to the provided `Series`, producing 'Tensor's.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_series: None, a `Series`, or a list of input `Series`, acting as
|
input_series: None, a `Series`, or a list of input `Series`, acting as
|
||||||
positional arguments.
|
positional arguments.
|
||||||
cache: a dict from Series reprs to Tensors.
|
cache: a dict from Series reprs to Tensors.
|
||||||
|
**kwargs: Additional keyword arguments, forwarded to each input `Series.build()` and to `_apply_transform()`.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A namedtuple of the output Tensors.
|
A namedtuple of the output Tensors.
|
||||||
@ -244,7 +245,7 @@ class Transform(object):
|
|||||||
if len(input_series) != self.input_valency:
|
if len(input_series) != self.input_valency:
|
||||||
raise ValueError("Expected %s input Series but received %s." %
|
raise ValueError("Expected %s input Series but received %s." %
|
||||||
(self.input_valency, len(input_series)))
|
(self.input_valency, len(input_series)))
|
||||||
input_tensors = [series.build(cache) for series in input_series]
|
input_tensors = [series.build(cache, **kwargs) for series in input_series]
|
||||||
|
|
||||||
# Note we cache each output individually, not just the entire output
|
# Note we cache each output individually, not just the entire output
|
||||||
# tuple. This allows using the graph as the cache, since it can sensibly
|
# tuple. This allows using the graph as the cache, since it can sensibly
|
||||||
@ -254,7 +255,7 @@ class Transform(object):
|
|||||||
output_tensors = [cache.get(output_repr) for output_repr in output_reprs]
|
output_tensors = [cache.get(output_repr) for output_repr in output_reprs]
|
||||||
|
|
||||||
if None in output_tensors:
|
if None in output_tensors:
|
||||||
result = self._apply_transform(input_tensors)
|
result = self._apply_transform(input_tensors, **kwargs)
|
||||||
for output_name, output_repr in zip(self.output_names, output_reprs):
|
for output_name, output_repr in zip(self.output_names, output_reprs):
|
||||||
cache[output_repr] = getattr(result, output_name)
|
cache[output_repr] = getattr(result, output_name)
|
||||||
else:
|
else:
|
||||||
@ -264,12 +265,13 @@ class Transform(object):
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _apply_transform(self, input_tensors):
|
def _apply_transform(self, input_tensors, **kwargs):
|
||||||
"""Applies the transformation to the `transform_input`.
|
"""Applies the transformation to the `transform_input`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_tensors: a list of Tensors representing the input to
|
input_tensors: a list of Tensors representing the input to
|
||||||
the Transform.
|
the Transform.
|
||||||
|
**kwargs: Additional keyword arguments, unused here.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A namedtuple of Tensors representing the transformed output.
|
A namedtuple of Tensors representing the transformed output.
|
||||||
|
@ -72,7 +72,7 @@ class Batch(AbstractBatchTransform):
|
|||||||
def name(self):
|
def name(self):
|
||||||
return "Batch"
|
return "Batch"
|
||||||
|
|
||||||
def _apply_transform(self, transform_input):
|
def _apply_transform(self, transform_input, **kwargs):
|
||||||
batched = input_ops.batch(transform_input,
|
batched = input_ops.batch(transform_input,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
num_threads=self.num_threads,
|
num_threads=self.num_threads,
|
||||||
@ -121,7 +121,7 @@ class ShuffleBatch(AbstractBatchTransform):
|
|||||||
def seed(self):
|
def seed(self):
|
||||||
return self._seed
|
return self._seed
|
||||||
|
|
||||||
def _apply_transform(self, transform_input):
|
def _apply_transform(self, transform_input, **kwargs):
|
||||||
batched = input_ops.shuffle_batch(transform_input,
|
batched = input_ops.shuffle_batch(transform_input,
|
||||||
batch_size=self.batch_size,
|
batch_size=self.batch_size,
|
||||||
capacity=self.queue_capacity,
|
capacity=self.queue_capacity,
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -53,7 +53,7 @@ class SeriesBinaryTransform(transform.Transform):
|
|||||||
def _output_names(self):
|
def _output_names(self):
|
||||||
return "output",
|
return "output",
|
||||||
|
|
||||||
def _apply_transform(self, input_tensors):
|
def _apply_transform(self, input_tensors, **kwargs):
|
||||||
# TODO(jamieas): consider supporting sparse inputs.
|
# TODO(jamieas): consider supporting sparse inputs.
|
||||||
if isinstance(input_tensors[0], ops.SparseTensor) or isinstance(
|
if isinstance(input_tensors[0], ops.SparseTensor) or isinstance(
|
||||||
input_tensors[1], ops.SparseTensor):
|
input_tensors[1], ops.SparseTensor):
|
||||||
@ -87,7 +87,7 @@ class ScalarBinaryTransform(transform.Transform):
|
|||||||
def _output_names(self):
|
def _output_names(self):
|
||||||
return "output",
|
return "output",
|
||||||
|
|
||||||
def _apply_transform(self, input_tensors):
|
def _apply_transform(self, input_tensors, **kwargs):
|
||||||
input_tensor = input_tensors[0]
|
input_tensor = input_tensors[0]
|
||||||
if isinstance(input_tensor, ops.SparseTensor):
|
if isinstance(input_tensor, ops.SparseTensor):
|
||||||
result = ops.SparseTensor(input_tensor.indices,
|
result = ops.SparseTensor(input_tensor.indices,
|
||||||
|
@ -77,12 +77,13 @@ class BooleanMask(transform.Transform):
|
|||||||
def _output_names(self):
|
def _output_names(self):
|
||||||
return "output",
|
return "output",
|
||||||
|
|
||||||
def _apply_transform(self, input_tensors):
|
def _apply_transform(self, input_tensors, **kwargs):
|
||||||
"""Applies the transformation to the `transform_input`.
|
"""Applies the transformation to the `transform_input`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_tensors: a list of Tensors representing the input to
|
input_tensors: a list of Tensors representing the input to
|
||||||
the Transform.
|
the Transform.
|
||||||
|
**kwargs: Additional keyword arguments, unused here.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A namedtuple of Tensors representing the transformed output.
|
A namedtuple of Tensors representing the transformed output.
|
||||||
|
@ -58,7 +58,7 @@ class CSVParser(transform.Transform):
|
|||||||
def default_values(self):
|
def default_values(self):
|
||||||
return self._default_values
|
return self._default_values
|
||||||
|
|
||||||
def _apply_transform(self, input_tensors):
|
def _apply_transform(self, input_tensors, **kwargs):
|
||||||
default_consts = [constant_op.constant(d, shape=[1])
|
default_consts = [constant_op.constant(d, shape=[1])
|
||||||
for d in self._default_values]
|
for d in self._default_values]
|
||||||
parsed_values = parsing_ops.decode_csv(input_tensors[0],
|
parsed_values = parsing_ops.decode_csv(input_tensors[0],
|
||||||
|
@ -47,12 +47,13 @@ class Densify(transform.Transform):
|
|||||||
def _output_names(self):
|
def _output_names(self):
|
||||||
return "output",
|
return "output",
|
||||||
|
|
||||||
def _apply_transform(self, input_tensors):
|
def _apply_transform(self, input_tensors, **kwargs):
|
||||||
"""Applies the transformation to the `transform_input`.
|
"""Applies the transformation to the `transform_input`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_tensors: a list of Tensors representing the input to
|
input_tensors: a list of Tensors representing the input to
|
||||||
the Transform.
|
the Transform.
|
||||||
|
**kwargs: Additional keyword arguments, unused here.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A namedtuple of Tensors representing the transformed output.
|
A namedtuple of Tensors representing the transformed output.
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -50,7 +50,7 @@ class Difference(transform.Transform):
|
|||||||
def _output_names(self):
|
def _output_names(self):
|
||||||
return "output",
|
return "output",
|
||||||
|
|
||||||
def _apply_transform(self, input_tensors):
|
def _apply_transform(self, input_tensors, **kwargs):
|
||||||
pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
|
pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
|
||||||
isinstance(input_tensors[1], ops.SparseTensor))
|
isinstance(input_tensors[1], ops.SparseTensor))
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ class ExampleParser(transform.Transform):
|
|||||||
def feature_definitions(self):
|
def feature_definitions(self):
|
||||||
return self._ordered_features
|
return self._ordered_features
|
||||||
|
|
||||||
def _apply_transform(self, input_tensors):
|
def _apply_transform(self, input_tensors, **kwargs):
|
||||||
parsed_values = parsing_ops.parse_example(input_tensors[0],
|
parsed_values = parsing_ops.parse_example(input_tensors[0],
|
||||||
features=self._ordered_features)
|
features=self._ordered_features)
|
||||||
# pylint: disable=not-callable
|
# pylint: disable=not-callable
|
||||||
|
@ -89,7 +89,7 @@ class BaseInMemorySource(transform.Transform):
|
|||||||
def input_valency(self):
|
def input_valency(self):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def _apply_transform(self, transform_input):
|
def _apply_transform(self, transform_input, **kwargs):
|
||||||
queue = feeding_functions.enqueue_data(self.data,
|
queue = feeding_functions.enqueue_data(self.data,
|
||||||
self.queue_capacity,
|
self.queue_capacity,
|
||||||
self.shuffle,
|
self.shuffle,
|
||||||
|
@ -32,7 +32,6 @@ class ReaderSource(transform.Transform):
|
|||||||
reader_kwargs=None,
|
reader_kwargs=None,
|
||||||
enqueue_size=None,
|
enqueue_size=None,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
num_epochs=None,
|
|
||||||
queue_capacity=None,
|
queue_capacity=None,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
min_after_dequeue=None,
|
min_after_dequeue=None,
|
||||||
@ -49,9 +48,6 @@ class ReaderSource(transform.Transform):
|
|||||||
is constructed.
|
is constructed.
|
||||||
enqueue_size: block size for each read operation.
|
enqueue_size: block size for each read operation.
|
||||||
batch_size: The desired batch size of output. Defaults to 1.
|
batch_size: The desired batch size of output. Defaults to 1.
|
||||||
num_epochs: the number of times that the reader should loop through all
|
|
||||||
the file names. If set to `None`, then the reader will continue
|
|
||||||
indefinitely.
|
|
||||||
queue_capacity: Capacity of the queue. Defaults to 10 * `batch_size`.
|
queue_capacity: Capacity of the queue. Defaults to 10 * `batch_size`.
|
||||||
shuffle: Whether records will be shuffled before returning. Defaults to
|
shuffle: Whether records will be shuffled before returning. Defaults to
|
||||||
false.
|
false.
|
||||||
@ -73,7 +69,6 @@ class ReaderSource(transform.Transform):
|
|||||||
self._batch_size = batch_size
|
self._batch_size = batch_size
|
||||||
self._queue_capacity = (batch_size * 10 if queue_capacity is None else
|
self._queue_capacity = (batch_size * 10 if queue_capacity is None else
|
||||||
queue_capacity)
|
queue_capacity)
|
||||||
self._num_epochs = num_epochs
|
|
||||||
self._shuffle = shuffle
|
self._shuffle = shuffle
|
||||||
self._min_after_dequeue = int(self.queue_capacity / 4 if min_after_dequeue
|
self._min_after_dequeue = int(self.queue_capacity / 4 if min_after_dequeue
|
||||||
is None else min_after_dequeue)
|
is None else min_after_dequeue)
|
||||||
@ -100,10 +95,6 @@ class ReaderSource(transform.Transform):
|
|||||||
def batch_size(self):
|
def batch_size(self):
|
||||||
return self._batch_size
|
return self._batch_size
|
||||||
|
|
||||||
@transform.parameter
|
|
||||||
def num_epochs(self):
|
|
||||||
return self._num_epochs
|
|
||||||
|
|
||||||
@transform.parameter
|
@transform.parameter
|
||||||
def queue_capacity(self):
|
def queue_capacity(self):
|
||||||
return self._queue_capacity
|
return self._queue_capacity
|
||||||
@ -136,11 +127,12 @@ class ReaderSource(transform.Transform):
|
|||||||
def _output_names(self):
|
def _output_names(self):
|
||||||
return ("index", "value")
|
return ("index", "value")
|
||||||
|
|
||||||
def _apply_transform(self, transform_input):
|
def _apply_transform(self, transform_input, **kwargs):
|
||||||
filename_queue = input_ops.string_input_producer(self.work_units,
|
filename_queue = input_ops.string_input_producer(
|
||||||
num_epochs=self.num_epochs,
|
self.work_units,
|
||||||
shuffle=self.shuffle,
|
num_epochs=kwargs.get("num_epochs"),
|
||||||
seed=self.seed)
|
shuffle=self.shuffle,
|
||||||
|
seed=self.seed)
|
||||||
reader_ops = []
|
reader_ops = []
|
||||||
for _ in range(self.num_threads):
|
for _ in range(self.num_threads):
|
||||||
reader = self._reader_cls(**self._reader_kwargs)
|
reader = self._reader_cls(**self._reader_kwargs)
|
||||||
@ -174,7 +166,6 @@ def TextFileSource(file_names,
|
|||||||
reader_kwargs=None,
|
reader_kwargs=None,
|
||||||
enqueue_size=1,
|
enqueue_size=1,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
num_epochs=None,
|
|
||||||
queue_capacity=None,
|
queue_capacity=None,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
min_after_dequeue=None,
|
min_after_dequeue=None,
|
||||||
@ -185,7 +176,6 @@ def TextFileSource(file_names,
|
|||||||
reader_kwargs=reader_kwargs,
|
reader_kwargs=reader_kwargs,
|
||||||
enqueue_size=enqueue_size,
|
enqueue_size=enqueue_size,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_epochs=num_epochs,
|
|
||||||
queue_capacity=queue_capacity,
|
queue_capacity=queue_capacity,
|
||||||
shuffle=shuffle,
|
shuffle=shuffle,
|
||||||
min_after_dequeue=min_after_dequeue,
|
min_after_dequeue=min_after_dequeue,
|
||||||
@ -197,7 +187,6 @@ def TFRecordSource(file_names,
|
|||||||
reader_kwargs=None,
|
reader_kwargs=None,
|
||||||
enqueue_size=1,
|
enqueue_size=1,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
num_epochs=None,
|
|
||||||
queue_capacity=None,
|
queue_capacity=None,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
min_after_dequeue=None,
|
min_after_dequeue=None,
|
||||||
@ -208,7 +197,6 @@ def TFRecordSource(file_names,
|
|||||||
reader_kwargs=reader_kwargs,
|
reader_kwargs=reader_kwargs,
|
||||||
enqueue_size=enqueue_size,
|
enqueue_size=enqueue_size,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_epochs=num_epochs,
|
|
||||||
queue_capacity=queue_capacity,
|
queue_capacity=queue_capacity,
|
||||||
shuffle=shuffle,
|
shuffle=shuffle,
|
||||||
min_after_dequeue=min_after_dequeue,
|
min_after_dequeue=min_after_dequeue,
|
||||||
|
@ -52,12 +52,13 @@ class Sparsify(transform.Transform):
|
|||||||
def _output_names(self):
|
def _output_names(self):
|
||||||
return "output",
|
return "output",
|
||||||
|
|
||||||
def _apply_transform(self, input_tensors):
|
def _apply_transform(self, input_tensors, **kwargs):
|
||||||
"""Applies the transformation to the `transform_input`.
|
"""Applies the transformation to the `transform_input`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_tensors: a list of Tensors representing the input to
|
input_tensors: a list of Tensors representing the input to
|
||||||
the Transform.
|
the Transform.
|
||||||
|
**kwargs: Additional keyword arguments, unused here.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A namedtuple of Tensors representing the transformed output.
|
A namedtuple of Tensors representing the transformed output.
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -44,7 +44,7 @@ class Sum(transform.Transform):
|
|||||||
def _output_names(self):
|
def _output_names(self):
|
||||||
return "output",
|
return "output",
|
||||||
|
|
||||||
def _apply_transform(self, input_tensors):
|
def _apply_transform(self, input_tensors, **kwargs):
|
||||||
pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
|
pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
|
||||||
isinstance(input_tensors[1], ops.SparseTensor))
|
isinstance(input_tensors[1], ops.SparseTensor))
|
||||||
|
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
@ -78,7 +78,7 @@ def register_unary_op(registered_name, operation):
|
|||||||
def _output_names(self):
|
def _output_names(self):
|
||||||
return "output"
|
return "output"
|
||||||
|
|
||||||
def _apply_transform(self, input_tensors):
|
def _apply_transform(self, input_tensors, **kwargs):
|
||||||
input_tensor = input_tensors[0]
|
input_tensor = input_tensors[0]
|
||||||
if isinstance(input_tensor, ops.SparseTensor):
|
if isinstance(input_tensor, ops.SparseTensor):
|
||||||
result = ops.SparseTensor(input_tensor.indices,
|
result = ops.SparseTensor(input_tensor.indices,
|
||||||
|
@ -208,10 +208,9 @@ class TensorFlowDataFrameTestCase(tf.test.TestCase):
|
|||||||
tensorflow_df = df.TensorFlowDataFrame.from_csv(
|
tensorflow_df = df.TensorFlowDataFrame.from_csv(
|
||||||
[data_path],
|
[data_path],
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_epochs=num_epochs,
|
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
default_values=default_values)
|
default_values=default_values)
|
||||||
actual_num_batches = len(list(tensorflow_df.run()))
|
actual_num_batches = len(list(tensorflow_df.run(num_epochs=num_epochs)))
|
||||||
self.assertEqual(expected_num_batches, actual_num_batches)
|
self.assertEqual(expected_num_batches, actual_num_batches)
|
||||||
|
|
||||||
def testFromCSVWithFeatureSpec(self):
|
def testFromCSVWithFeatureSpec(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user