Add kwargs to Transform.apply for num_epochs
Change: 128989804
parent 52c0418614
commit 48e869f0e3
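In short, this change threads arbitrary keyword arguments from DataFrame.build through Series.build and Transform.build_transitive down to each Transform._apply_transform, so a per-run setting such as num_epochs can be supplied when the graph is built rather than baked into the source Transform at construction time. A minimal runnable sketch of the plumbing pattern (a toy illustration, not the TensorFlow API; the Series and cache machinery is elided):

class Transform(object):
  def build_transitive(self, input_series, cache=None, **kwargs):
    # Forward run-time kwargs to every upstream build and to this
    # transform's own application.
    cache = {} if cache is None else cache
    input_tensors = [s.build(cache, **kwargs) for s in input_series]
    return self._apply_transform(input_tensors, **kwargs)

  def _apply_transform(self, input_tensors, **kwargs):
    raise NotImplementedError()

class ReaderSource(Transform):
  def _apply_transform(self, transform_input, **kwargs):
    # Only the transforms that care inspect kwargs; all others accept
    # and ignore them. None preserves the old "loop forever" default.
    return {"num_epochs": kwargs.get("num_epochs")}

print(ReaderSource().build_transitive([], num_epochs=5))  # {'num_epochs': 5}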
@@ -117,10 +117,11 @@ class DataFrame(object):
       value = [value]
     self.assign(**dict(zip(key, value)))

-  def build(self):
+  def build(self, **kwargs):
     # We do not allow passing a cache here, because that would encourage
     # working around the rule that DataFrames cannot be expected to be
     # synced with each other (e.g., they shuffle independently).
     cache = {}
-    tensors = {name: c.build(cache) for name, c in self._columns.items()}
+    tensors = {name: c.build(cache, **kwargs)
+               for name, c in self._columns.items()}
     return tensors
@@ -91,7 +91,8 @@ def _build_alternate_universe(
 def to_feature_columns_and_input_fn(dataframe,
                                     base_input_keys_with_defaults,
                                     feature_keys,
-                                    target_keys=None):
+                                    target_keys=None,
+                                    **kwargs):
   """Build a list of FeatureColumns and an input_fn for use with Estimator.

   Args:
@@ -103,6 +104,7 @@ def to_feature_columns_and_input_fn(dataframe,
       These may include base features and/or derived features.
     target_keys: the names of columns to be used as targets. None is
       acceptable for unsupervised learning.
+    **kwargs: Additional keyword arguments, unused here.

   Returns:
     A tuple of two elements:
@@ -155,10 +157,11 @@ def to_feature_columns_and_input_fn(dataframe,

   # Build an input_fn suitable for use with Estimator.
   def input_fn():
     """An input_fn() for feeding the given set of DataFrameColumns."""
     # It's important to build all the tensors together in one DataFrame.
     # If we did df.select() for both key sets and then build those, the two
     # resulting DataFrames would be shuffled independently.
-    tensors = limited_dataframe.build()
+    tensors = limited_dataframe.build(**kwargs)

     base_input_features = {key: tensors[key] for key in base_input_keys}
     targets = {key: tensors[key] for key in target_keys}
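With **kwargs accepted here, Estimator users can pass per-run reader settings straight through to the eventual build call. A hedged usage sketch (the dataframe and key names are hypothetical placeholders):

feature_columns, input_fn = to_feature_columns_and_input_fn(
    dataframe,                                 # built elsewhere
    base_input_keys_with_defaults={"age": 0},  # hypothetical keys
    feature_keys=["age_bucket"],
    target_keys=["label"],
    num_epochs=5)  # forwarded via **kwargs to limited_dataframe.build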
@@ -98,7 +98,7 @@ class Series(object):
       return transform_cls
     return register

-  def build(self, cache):
+  def build(self, cache, **kwargs):
     """Returns a Tensor."""
     raise NotImplementedError()

@@ -122,7 +122,7 @@ class PredefinedSeries(Series):
   def required_base_features(self):
     return {self.name: self.feature_spec}

-  def build(self, cache):
+  def build(self, cache, **kwargs):
     try:
       return cache[self.name]
     except KeyError:
@@ -171,10 +171,11 @@ class TransformedSeries(Series):
     result.update(s.required_base_features)
     return result

-  def build(self, cache=None):
+  def build(self, cache=None, **kwargs):
     if cache is None:
       cache = {}
-    all_outputs = self._transform.build_transitive(self._input_series, cache)
+    all_outputs = self._transform.build_transitive(
+        self._input_series, cache, **kwargs)
     return getattr(all_outputs, self._output_name)

   def __repr__(self):
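One consequence worth noting: built tensors are still memoized in cache, keyed by each Series repr, and the repr does not encode the kwargs. Reusing one cache across builds with different kwargs would therefore return stale tensors (a hypothetical misuse; DataFrame.build guards against it by creating a fresh cache per call):

shared_cache = {}
t1 = series.build(shared_cache, num_epochs=1)
t2 = series.build(shared_cache, num_epochs=2)  # silently returns t1's tensors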
@@ -83,7 +83,8 @@ class TensorFlowDataFrame(df.DataFrame):
           graph=None,
           session=None,
           start_queues=True,
-          initialize_variables=True):
+          initialize_variables=True,
+          **kwargs):
     """Builds and runs the columns of the `DataFrame` and yields batches.

     This is a generator that yields a dictionary mapping column names to
@@ -97,6 +98,7 @@ class TensorFlowDataFrame(df.DataFrame):
       start_queues: if true, queues will be started before running and halted
         after producing `n` batches.
       initialize_variables: if true, variables will be initialized.
+      **kwargs: Additional keyword arguments, unused here.

     Yields:
       A dictionary, mapping column names to the values resulting from running
@@ -107,7 +109,7 @@ class TensorFlowDataFrame(df.DataFrame):
     with graph.as_default():
       if session is None:
         session = sess.Session()
-      self_built = self.build()
+      self_built = self.build(**kwargs)
       keys = list(self_built.keys())
       cols = list(self_built.values())
       if initialize_variables:
@@ -208,7 +210,7 @@ class TensorFlowDataFrame(df.DataFrame):

   @classmethod
   def _from_csv_base(cls, filepatterns, get_default_values, has_header,
-                     column_names, num_epochs, num_threads, enqueue_size,
+                     column_names, num_threads, enqueue_size,
                      batch_size, queue_capacity, min_after_dequeue, shuffle,
                      seed):
     """Create a `DataFrame` from CSV files.
@@ -223,9 +225,6 @@ class TensorFlowDataFrame(df.DataFrame):
         each column, given the column names.
       has_header: whether or not the CSV files have headers.
       column_names: a list of names for the columns in the CSV files.
-      num_epochs: the number of times that the reader should loop through all
-        the file names. If set to `None`, then the reader will continue
-        indefinitely.
       num_threads: the number of readers that will work in parallel.
       enqueue_size: block size for each read operation.
       batch_size: desired batch size.
@@ -265,7 +264,6 @@ class TensorFlowDataFrame(df.DataFrame):
         reader_kwargs=reader_kwargs,
         enqueue_size=enqueue_size,
         batch_size=batch_size,
-        num_epochs=num_epochs,
         queue_capacity=queue_capacity,
         shuffle=shuffle,
         min_after_dequeue=min_after_dequeue,
@@ -287,7 +285,6 @@ class TensorFlowDataFrame(df.DataFrame):
                default_values,
                has_header=True,
                column_names=None,
-               num_epochs=None,
                num_threads=1,
                enqueue_size=None,
                batch_size=32,
@@ -306,9 +303,6 @@ class TensorFlowDataFrame(df.DataFrame):
       default_values: a list of default values for each column.
       has_header: whether or not the CSV files have headers.
       column_names: a list of names for the columns in the CSV files.
-      num_epochs: the number of times that the reader should loop through all
-        the file names. If set to `None`, then the reader will continue
-        indefinitely.
       num_threads: the number of readers that will work in parallel.
       enqueue_size: block size for each read operation.
       batch_size: desired batch size.
@@ -332,7 +326,7 @@ class TensorFlowDataFrame(df.DataFrame):
       return default_values

     return cls._from_csv_base(filepatterns, get_default_values, has_header,
-                              column_names, num_epochs, num_threads,
+                              column_names, num_threads,
                               enqueue_size, batch_size, queue_capacity,
                               min_after_dequeue, shuffle, seed)

@@ -342,7 +336,6 @@ class TensorFlowDataFrame(df.DataFrame):
               feature_spec,
               has_header=True,
               column_names=None,
-              num_epochs=None,
               num_threads=1,
               enqueue_size=None,
               batch_size=32,
@@ -362,9 +355,6 @@ class TensorFlowDataFrame(df.DataFrame):
         `VarLenFeature`.
       has_header: whether or not the CSV files have headers.
       column_names: a list of names for the columns in the CSV files.
-      num_epochs: the number of times that the reader should loop through all
-        the file names. If set to `None`, then the reader will continue
-        indefinitely.
       num_threads: the number of readers that will work in parallel.
       enqueue_size: block size for each read operation.
       batch_size: desired batch size.
@@ -387,7 +377,7 @@ class TensorFlowDataFrame(df.DataFrame):
       return [_get_default_value(feature_spec[name]) for name in column_names]

     dataframe = cls._from_csv_base(filepatterns, get_default_values, has_header,
-                                   column_names, num_epochs, num_threads,
+                                   column_names, num_threads,
                                    enqueue_size, batch_size, queue_capacity,
                                    min_after_dequeue, shuffle, seed)

@@ -405,7 +395,6 @@ class TensorFlowDataFrame(df.DataFrame):
                filepatterns,
                features,
                reader_cls=io_ops.TFRecordReader,
-               num_epochs=None,
                num_threads=1,
                enqueue_size=None,
                batch_size=32,
@@ -421,9 +410,6 @@ class TensorFlowDataFrame(df.DataFrame):
         `FixedLenFeature`.
       reader_cls: a subclass of `tensorflow.ReaderBase` that will be used to
         read the `Example`s.
-      num_epochs: the number of times that the reader should loop through all
-        the file names. If set to `None`, then the reader will continue
-        indefinitely.
       num_threads: the number of readers that will work in parallel.
       enqueue_size: block size for each read operation.
       batch_size: desired batch size.
@@ -454,7 +440,6 @@ class TensorFlowDataFrame(df.DataFrame):
         filenames,
         enqueue_size=enqueue_size,
         batch_size=batch_size,
-        num_epochs=num_epochs,
         queue_capacity=queue_capacity,
         shuffle=shuffle,
         min_after_dequeue=min_after_dequeue,
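The user-visible effect of the deletions above: num_epochs is no longer fixed when the DataFrame is constructed via from_csv, from_csv_with_feature_spec, or from_examples; it is supplied per call to run. A hedged usage sketch (the path, defaults, and sizes are hypothetical):

tensorflow_df = df.TensorFlowDataFrame.from_csv(
    ["/tmp/data.csv"],        # hypothetical file
    default_values=[0.0, ""],
    batch_size=32,
    shuffle=False)
one_epoch = list(tensorflow_df.run(num_epochs=1))
five_epochs = list(tensorflow_df.run(num_epochs=5))  # same DataFrame, new epoch count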
@@ -223,13 +223,14 @@ class Transform(object):
     # pylint: disable=not-callable
     return self.return_type(*output_series)

-  def build_transitive(self, input_series, cache=None):
+  def build_transitive(self, input_series, cache=None, **kwargs):
     """Apply this `Transform` to the provided `Series`, producing 'Tensor's.

     Args:
       input_series: None, a `Series`, or a list of input `Series`, acting as
         positional arguments.
       cache: a dict from Series reprs to Tensors.
+      **kwargs: Additional keyword arguments, unused here.

     Returns:
       A namedtuple of the output Tensors.
@@ -244,7 +245,7 @@ class Transform(object):
     if len(input_series) != self.input_valency:
       raise ValueError("Expected %s input Series but received %s." %
                        (self.input_valency, len(input_series)))
-    input_tensors = [series.build(cache) for series in input_series]
+    input_tensors = [series.build(cache, **kwargs) for series in input_series]

     # Note we cache each output individually, not just the entire output
     # tuple. This allows using the graph as the cache, since it can sensibly
@@ -254,7 +255,7 @@ class Transform(object):
     output_tensors = [cache.get(output_repr) for output_repr in output_reprs]

     if None in output_tensors:
-      result = self._apply_transform(input_tensors)
+      result = self._apply_transform(input_tensors, **kwargs)
       for output_name, output_repr in zip(self.output_names, output_reprs):
         cache[output_repr] = getattr(result, output_name)
     else:
@@ -264,12 +265,13 @@ class Transform(object):
     return result

   @abstractmethod
-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     """Applies the transformation to the `transform_input`.

     Args:
       input_tensors: a list of Tensors representing the input to
         the Transform.
+      **kwargs: Additional keyword arguments, unused here.

     Returns:
       A namedtuple of Tensors representing the transformed output.
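Since the abstract method now takes **kwargs, every concrete Transform must accept them even when it ignores them, which is exactly what the remaining hunks below do mechanically. A minimal hypothetical subclass sketched against this interface (Identity is not part of the library):

class Identity(transform.Transform):
  @property
  def name(self):
    return "Identity"

  @property
  def input_valency(self):
    return 1

  @property
  def _output_names(self):
    return "output",

  def _apply_transform(self, input_tensors, **kwargs):
    # Accept **kwargs for interface compatibility; this transform ignores them.
    # pylint: disable=not-callable
    return self.return_type(input_tensors[0])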
@@ -72,7 +72,7 @@ class Batch(AbstractBatchTransform):
   def name(self):
     return "Batch"

-  def _apply_transform(self, transform_input):
+  def _apply_transform(self, transform_input, **kwargs):
     batched = input_ops.batch(transform_input,
                               batch_size=self.batch_size,
                               num_threads=self.num_threads,
@@ -121,7 +121,7 @@ class ShuffleBatch(AbstractBatchTransform):
   def seed(self):
     return self._seed

-  def _apply_transform(self, transform_input):
+  def _apply_transform(self, transform_input, **kwargs):
     batched = input_ops.shuffle_batch(transform_input,
                                       batch_size=self.batch_size,
                                       capacity=self.queue_capacity,
@@ -1,4 +1,4 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -53,7 +53,7 @@ class SeriesBinaryTransform(transform.Transform):
   def _output_names(self):
     return "output",

-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     # TODO(jamieas): consider supporting sparse inputs.
     if isinstance(input_tensors[0], ops.SparseTensor) or isinstance(
         input_tensors[1], ops.SparseTensor):
@@ -87,7 +87,7 @@ class ScalarBinaryTransform(transform.Transform):
   def _output_names(self):
     return "output",

-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     input_tensor = input_tensors[0]
     if isinstance(input_tensor, ops.SparseTensor):
       result = ops.SparseTensor(input_tensor.indices,
@@ -77,12 +77,13 @@ class BooleanMask(transform.Transform):
   def _output_names(self):
     return "output",

-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     """Applies the transformation to the `transform_input`.

     Args:
       input_tensors: a list of Tensors representing the input to
         the Transform.
+      **kwargs: Additional keyword arguments, unused here.

     Returns:
       A namedtuple of Tensors representing the transformed output.
@@ -58,7 +58,7 @@ class CSVParser(transform.Transform):
   def default_values(self):
     return self._default_values

-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     default_consts = [constant_op.constant(d, shape=[1])
                       for d in self._default_values]
     parsed_values = parsing_ops.decode_csv(input_tensors[0],
@@ -47,12 +47,13 @@ class Densify(transform.Transform):
   def _output_names(self):
     return "output",

-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     """Applies the transformation to the `transform_input`.

     Args:
       input_tensors: a list of Tensors representing the input to
         the Transform.
+      **kwargs: Additional keyword arguments, unused here.

     Returns:
       A namedtuple of Tensors representing the transformed output.
@@ -1,4 +1,4 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -50,7 +50,7 @@ class Difference(transform.Transform):
   def _output_names(self):
     return "output",

-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
                      isinstance(input_tensors[1], ops.SparseTensor))

@@ -61,7 +61,7 @@ class ExampleParser(transform.Transform):
   def feature_definitions(self):
     return self._ordered_features

-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     parsed_values = parsing_ops.parse_example(input_tensors[0],
                                               features=self._ordered_features)
     # pylint: disable=not-callable
@@ -89,7 +89,7 @@ class BaseInMemorySource(transform.Transform):
   def input_valency(self):
     return 0

-  def _apply_transform(self, transform_input):
+  def _apply_transform(self, transform_input, **kwargs):
     queue = feeding_functions.enqueue_data(self.data,
                                            self.queue_capacity,
                                            self.shuffle,
@@ -32,7 +32,6 @@ class ReaderSource(transform.Transform):
                reader_kwargs=None,
                enqueue_size=None,
                batch_size=1,
-               num_epochs=None,
                queue_capacity=None,
                shuffle=False,
                min_after_dequeue=None,
@@ -49,9 +48,6 @@ class ReaderSource(transform.Transform):
         is constructed.
       enqueue_size: block size for each read operation.
       batch_size: The desired batch size of output. Defaults to 1.
-      num_epochs: the number of times that the reader should loop through all
-        the file names. If set to `None`, then the reader will continue
-        indefinitely.
       queue_capacity: Capacity of the queue. Defaults to 10 * `batch_size`.
       shuffle: Whether records will be shuffled before returning. Defaults to
         false.
@@ -73,7 +69,6 @@ class ReaderSource(transform.Transform):
     self._batch_size = batch_size
     self._queue_capacity = (batch_size * 10 if queue_capacity is None else
                             queue_capacity)
-    self._num_epochs = num_epochs
     self._shuffle = shuffle
     self._min_after_dequeue = int(self.queue_capacity / 4 if min_after_dequeue
                                   is None else min_after_dequeue)
@@ -100,10 +95,6 @@ class ReaderSource(transform.Transform):
   def batch_size(self):
     return self._batch_size

-  @transform.parameter
-  def num_epochs(self):
-    return self._num_epochs
-
   @transform.parameter
   def queue_capacity(self):
     return self._queue_capacity
@@ -136,9 +127,10 @@ class ReaderSource(transform.Transform):
   def _output_names(self):
     return ("index", "value")

-  def _apply_transform(self, transform_input):
-    filename_queue = input_ops.string_input_producer(self.work_units,
-                                                     num_epochs=self.num_epochs,
+  def _apply_transform(self, transform_input, **kwargs):
+    filename_queue = input_ops.string_input_producer(
+        self.work_units,
+        num_epochs=kwargs.get("num_epochs"),
         shuffle=self.shuffle,
         seed=self.seed)
     reader_ops = []
@@ -174,7 +166,6 @@ def TextFileSource(file_names,
                    reader_kwargs=None,
                    enqueue_size=1,
                    batch_size=1,
-                   num_epochs=None,
                    queue_capacity=None,
                    shuffle=False,
                    min_after_dequeue=None,
@@ -185,7 +176,6 @@ def TextFileSource(file_names,
                    reader_kwargs=reader_kwargs,
                    enqueue_size=enqueue_size,
                    batch_size=batch_size,
-                   num_epochs=num_epochs,
                    queue_capacity=queue_capacity,
                    shuffle=shuffle,
                    min_after_dequeue=min_after_dequeue,
@@ -197,7 +187,6 @@ def TFRecordSource(file_names,
                    reader_kwargs=None,
                    enqueue_size=1,
                    batch_size=1,
-                   num_epochs=None,
                    queue_capacity=None,
                    shuffle=False,
                    min_after_dequeue=None,
@@ -208,7 +197,6 @@ def TFRecordSource(file_names,
                    reader_kwargs=reader_kwargs,
                    enqueue_size=enqueue_size,
                    batch_size=batch_size,
-                   num_epochs=num_epochs,
                    queue_capacity=queue_capacity,
                    shuffle=shuffle,
                    min_after_dequeue=min_after_dequeue,
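Note that ReaderSource now reads the epoch count with kwargs.get("num_epochs"), which yields None when run() is called without it; string_input_producer then cycles through the file names indefinitely, so the old default behavior is preserved. Spelled out (equivalent, for illustration):

num_epochs = kwargs["num_epochs"] if "num_epochs" in kwargs else None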
@@ -52,12 +52,13 @@ class Sparsify(transform.Transform):
   def _output_names(self):
     return "output",

-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     """Applies the transformation to the `transform_input`.

     Args:
       input_tensors: a list of Tensors representing the input to
         the Transform.
+      **kwargs: Additional keyword arguments, unused here.

     Returns:
       A namedtuple of Tensors representing the transformed output.
@@ -1,4 +1,4 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -44,7 +44,7 @@ class Sum(transform.Transform):
   def _output_names(self):
     return "output",

-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     pair_sparsity = (isinstance(input_tensors[0], ops.SparseTensor),
                      isinstance(input_tensors[1], ops.SparseTensor))

@@ -1,4 +1,4 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -78,7 +78,7 @@ def register_unary_op(registered_name, operation):
   def _output_names(self):
     return "output"

-  def _apply_transform(self, input_tensors):
+  def _apply_transform(self, input_tensors, **kwargs):
     input_tensor = input_tensors[0]
     if isinstance(input_tensor, ops.SparseTensor):
       result = ops.SparseTensor(input_tensor.indices,
@@ -208,10 +208,9 @@ class TensorFlowDataFrameTestCase(tf.test.TestCase):
     tensorflow_df = df.TensorFlowDataFrame.from_csv(
         [data_path],
         batch_size=batch_size,
-        num_epochs=num_epochs,
         shuffle=False,
         default_values=default_values)
-    actual_num_batches = len(list(tensorflow_df.run()))
+    actual_num_batches = len(list(tensorflow_df.run(num_epochs=num_epochs)))
     self.assertEqual(expected_num_batches, actual_num_batches)

   def testFromCSVWithFeatureSpec(self):
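The updated test passes the epoch count to run() instead of from_csv(). To make the assertion concrete with hypothetical numbers: for 100 CSV rows, batch_size=8, and num_epochs=2, the reader emits 200 records, which batch into 200 / 8 = 25 full batches, so run(num_epochs=2) should terminate after exactly 25 batches rather than looping forever.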