[tf.data] Convert dataset arguments to tensors as early as possible.

This change raises a `TypeError` earlier if (for example) the `batch_size`
argument to `Dataset.batch()` has the incorrect type.

PiperOrigin-RevId: 173126678
This commit is contained in:
Derek Murray 2017-10-23 09:34:30 -07:00 committed by TensorFlower Gardener
parent 4f7503a876
commit fc56349b7f

View File

@ -1057,21 +1057,21 @@ class RangeDataset(Dataset):
def _parse_args(self, *args):
if len(args) == 1:
self._start = self._build_tensor(0, "start")
self._stop = args[0]
self._stop = self._build_tensor(args[0], "stop")
self._step = self._build_tensor(1, "step")
elif len(args) == 2:
self._start = args[0]
self._stop = args[1]
self._start = self._build_tensor(args[0], "start")
self._stop = self._build_tensor(args[1], "stop")
self._step = self._build_tensor(1, "step")
elif len(args) == 3:
self._start = args[0]
self._stop = args[1]
self._step = args[2]
self._start = self._build_tensor(args[0], "start")
self._stop = self._build_tensor(args[1], "stop")
self._step = self._build_tensor(args[2], "step")
else:
raise ValueError("Invalid arguments to RangeDataset: %s" % str(args))
def _build_tensor(self, int64_value, name):
return constant_op.constant(int64_value, dtype=dtypes.int64, name=name)
return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
def _as_variant_tensor(self):
return gen_dataset_ops.range_dataset(
@ -1217,7 +1217,8 @@ class BatchDataset(Dataset):
"""See `Dataset.batch()` for details."""
super(BatchDataset, self).__init__()
self._input_dataset = input_dataset
self._batch_size = batch_size
self._batch_size = ops.convert_to_tensor(batch_size, dtype=dtypes.int64,
name="batch_size")
def _as_variant_tensor(self):
return gen_dataset_ops.batch_dataset(
@ -1285,7 +1286,8 @@ class PaddedBatchDataset(Dataset):
"""See `Dataset.batch()` for details."""
super(PaddedBatchDataset, self).__init__()
self._input_dataset = input_dataset
self._batch_size = batch_size
self._batch_size = ops.convert_to_tensor(batch_size, dtype=dtypes.int64,
name="batch_size")
padding_values = (padding_values if padding_values is not None else
self._default_padding(input_dataset))
self._padded_shapes = nest.map_structure_up_to(
@ -1509,8 +1511,10 @@ class InterleaveDataset(Dataset):
self._map_func = tf_map_func
self._map_func.add_to_graph(ops.get_default_graph())
self._cycle_length = ops.convert_to_tensor(cycle_length, dtype=dtypes.int64)
self._block_length = ops.convert_to_tensor(block_length, dtype=dtypes.int64)
self._cycle_length = ops.convert_to_tensor(cycle_length, dtype=dtypes.int64,
name="cycle_length")
self._block_length = ops.convert_to_tensor(block_length, dtype=dtypes.int64,
name="block_length")
def _as_variant_tensor(self):
return gen_dataset_ops.interleave_dataset(
@ -1587,7 +1591,8 @@ class PrefetchDataset(Dataset):
"""See `Dataset.prefetch()` for details."""
super(PrefetchDataset, self).__init__()
self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(buffer_size, dtype=dtypes.int64)
self._buffer_size = ops.convert_to_tensor(buffer_size, dtype=dtypes.int64,
name="buffer_size")
def _as_variant_tensor(self):
return gen_dataset_ops.prefetch_dataset(