[tf.data] Minor fix to remove unnecessary difference between the implementations of the batch and padded batch reducers.
PiperOrigin-RevId: 211706766
This commit is contained in:
parent
7288b3da07
commit
7e2577b098
@ -272,9 +272,9 @@ def _padded_batch_dense_window(dataset, padded_shape, padding_value=None):
|
|||||||
padding_value = 0
|
padding_value = 0
|
||||||
|
|
||||||
def batch_init_fn(_):
|
def batch_init_fn(_):
|
||||||
return array_ops.fill(
|
batch_shape = array_ops.concat(
|
||||||
array_ops.concat([np.array([0], dtype=np.int32), padded_shape], 0),
|
[np.array([0], dtype=np.int32), padded_shape], 0)
|
||||||
constant_op.constant(padding_value, dtype=dataset.output_types))
|
return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
|
||||||
|
|
||||||
def batch_reduce_fn(state, value):
|
def batch_reduce_fn(state, value):
|
||||||
return array_ops.concat([state, [value]], 0)
|
return array_ops.concat([state, [value]], 0)
|
||||||
|
Loading…
Reference in New Issue
Block a user