[tf.data] Convert dataset arguments to tensors as early as possible.
This change raises a `TypeError` earlier if (for example) the `batch_size` argument to `Dataset.batch()` has an incorrect type.
PiperOrigin-RevId: 173126678
parent 4f7503a876
commit fc56349b7f
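For context, the sketch below (not part of the commit) illustrates the behavior this change targets, assuming the public `tf.data.Dataset` API of this era. Because `batch_size` is now converted to an `int64` tensor inside the `BatchDataset` constructor, an ill-typed argument fails at the `Dataset.batch()` call site instead of later, when the batch op would be constructed or run; the exact exception raised (the message, or `ValueError` vs. `TypeError` for some inputs) can vary across TensorFlow versions.

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(100)

# An ill-typed batch_size (illustrative value) is rejected as soon as the
# BatchDataset is constructed, because ops.convert_to_tensor(..., dtype=int64)
# now runs in its __init__.
try:
  dataset.batch("thirty-two")
except (TypeError, ValueError) as e:     # commit message says TypeError; the
  print("rejected at construction:", e)  # exact type can vary by TF version

# A well-typed Python int is converted to a scalar int64 tensor named
# "batch_size" and the pipeline is built as usual.
batched = dataset.batch(32)
```

The added `name=` arguments also give the converted scalars readable names in the graph, which makes the resulting tensors easier to identify when debugging.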
@@ -1057,21 +1057,21 @@ class RangeDataset(Dataset):
   def _parse_args(self, *args):
     if len(args) == 1:
       self._start = self._build_tensor(0, "start")
-      self._stop = args[0]
+      self._stop = self._build_tensor(args[0], "stop")
       self._step = self._build_tensor(1, "step")
     elif len(args) == 2:
-      self._start = args[0]
-      self._stop = args[1]
+      self._start = self._build_tensor(args[0], "start")
+      self._stop = self._build_tensor(args[1], "stop")
       self._step = self._build_tensor(1, "step")
     elif len(args) == 3:
-      self._start = args[0]
-      self._stop = args[1]
-      self._step = args[2]
+      self._start = self._build_tensor(args[0], "start")
+      self._stop = self._build_tensor(args[1], "stop")
+      self._step = self._build_tensor(args[2], "step")
     else:
       raise ValueError("Invalid arguments to RangeDataset: %s" % str(args))

   def _build_tensor(self, int64_value, name):
-    return constant_op.constant(int64_value, dtype=dtypes.int64, name=name)
+    return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)

   def _as_variant_tensor(self):
     return gen_dataset_ops.range_dataset(
@@ -1217,7 +1217,8 @@ class BatchDataset(Dataset):
     """See `Dataset.batch()` for details."""
     super(BatchDataset, self).__init__()
     self._input_dataset = input_dataset
-    self._batch_size = batch_size
+    self._batch_size = ops.convert_to_tensor(batch_size, dtype=dtypes.int64,
+                                             name="batch_size")

   def _as_variant_tensor(self):
     return gen_dataset_ops.batch_dataset(
@@ -1285,7 +1286,8 @@ class PaddedBatchDataset(Dataset):
     """See `Dataset.batch()` for details."""
     super(PaddedBatchDataset, self).__init__()
     self._input_dataset = input_dataset
-    self._batch_size = batch_size
+    self._batch_size = ops.convert_to_tensor(batch_size, dtype=dtypes.int64,
+                                             name="batch_size")
     padding_values = (padding_values if padding_values is not None else
                       self._default_padding(input_dataset))
     self._padded_shapes = nest.map_structure_up_to(
@@ -1509,8 +1511,10 @@ class InterleaveDataset(Dataset):
     self._map_func = tf_map_func
     self._map_func.add_to_graph(ops.get_default_graph())

-    self._cycle_length = ops.convert_to_tensor(cycle_length, dtype=dtypes.int64)
-    self._block_length = ops.convert_to_tensor(block_length, dtype=dtypes.int64)
+    self._cycle_length = ops.convert_to_tensor(cycle_length, dtype=dtypes.int64,
+                                               name="cycle_length")
+    self._block_length = ops.convert_to_tensor(block_length, dtype=dtypes.int64,
+                                               name="block_length")

   def _as_variant_tensor(self):
     return gen_dataset_ops.interleave_dataset(
@@ -1587,7 +1591,8 @@ class PrefetchDataset(Dataset):
     """See `Dataset.prefetch()` for details."""
     super(PrefetchDataset, self).__init__()
     self._input_dataset = input_dataset
-    self._buffer_size = ops.convert_to_tensor(buffer_size, dtype=dtypes.int64)
+    self._buffer_size = ops.convert_to_tensor(buffer_size, dtype=dtypes.int64,
+                                              name="buffer_size")

   def _as_variant_tensor(self):
     return gen_dataset_ops.prefetch_dataset(
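One side effect worth noting in the `RangeDataset` hunk: switching `_build_tensor()` from `constant_op.constant` to `ops.convert_to_tensor` means `start`/`stop`/`step` no longer have to be Python integers; anything convertible to an `int64` scalar, including an existing tensor, should pass through. A rough sketch under that assumption:

```python
import tensorflow as tf

# A Python int behaves as before: it becomes an int64 scalar tensor.
ds_from_int = tf.data.Dataset.range(10)

# An existing int64 scalar tensor should also be accepted, since
# ops.convert_to_tensor passes compatible tensors through unchanged
# (the earlier constant_op.constant path generally could not).
stop = tf.constant(10, dtype=tf.int64)
ds_from_tensor = tf.data.Dataset.range(stop)
```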