Quick return with simpler reshape op if the frame hop and frame length are statically known and their values are equal (i.e. no overlap). #tf-signal
PiperOrigin-RevId: 208002136
This commit is contained in:
parent
4ff625bc58
commit
f2494ede82
@ -90,22 +90,28 @@ def overlap_and_add(signal, frame_step, name=None):
|
||||
raise ValueError("frame_step must be an integer. Got %s" %
|
||||
frame_step.dtype)
|
||||
|
||||
# If frame_length and frame_step are known at graph construction time, check
|
||||
# frame_step is less than or equal to frame_length.
|
||||
frame_step_static = tensor_util.constant_value(frame_step)
|
||||
if (frame_step_static is not None and signal.shape.ndims is not None and
|
||||
signal.shape[-1].value is not None and
|
||||
frame_step_static > signal.shape[-1].value):
|
||||
raise ValueError(
|
||||
"frame_step (%d) must be less than or equal to frame_length (%d)" % (
|
||||
frame_step_static, signal.shape[-1].value))
|
||||
|
||||
signal_shape = array_ops.shape(signal)
|
||||
|
||||
# All dimensions that are not part of the overlap-and-add. Can be empty for
|
||||
# rank 2 inputs.
|
||||
outer_dimensions = signal_shape[:-2]
|
||||
|
||||
# If frame_length and frame_step are known at graph construction time, check
|
||||
# frame_step is less than or equal to frame_length.
|
||||
frame_step_static = tensor_util.constant_value(frame_step)
|
||||
if (frame_step_static is not None and signal.shape.ndims is not None and
|
||||
signal.shape[-1].value is not None):
|
||||
if frame_step_static > signal.shape[-1].value:
|
||||
raise ValueError(
|
||||
"frame_step (%d) must be less than or equal to "
|
||||
"frame_length (%d)" % (
|
||||
frame_step_static, signal.shape[-1].value))
|
||||
# If frame_length is equal to frame_step, there's no overlap so just
|
||||
# reshape the tensor.
|
||||
if frame_step_static == signal.shape[-1].value:
|
||||
return array_ops.reshape(signal, array_ops.concat(
|
||||
[outer_dimensions, [-1]], 0))
|
||||
|
||||
signal_rank = array_ops.rank(signal)
|
||||
frames = signal_shape[-2]
|
||||
frame_length = signal_shape[-1]
|
||||
|
Loading…
x
Reference in New Issue
Block a user