Quick return with simpler reshape op if the frame hop and frame length are statically known and their values are equal (i.e. no overlap). #tf-signal
PiperOrigin-RevId: 208002136
This commit is contained in:
parent
4ff625bc58
commit
f2494ede82
@ -90,22 +90,28 @@ def overlap_and_add(signal, frame_step, name=None):
|
|||||||
raise ValueError("frame_step must be an integer. Got %s" %
|
raise ValueError("frame_step must be an integer. Got %s" %
|
||||||
frame_step.dtype)
|
frame_step.dtype)
|
||||||
|
|
||||||
# If frame_length and frame_step are known at graph construction time, check
|
|
||||||
# frame_step is less than or equal to frame_length.
|
|
||||||
frame_step_static = tensor_util.constant_value(frame_step)
|
|
||||||
if (frame_step_static is not None and signal.shape.ndims is not None and
|
|
||||||
signal.shape[-1].value is not None and
|
|
||||||
frame_step_static > signal.shape[-1].value):
|
|
||||||
raise ValueError(
|
|
||||||
"frame_step (%d) must be less than or equal to frame_length (%d)" % (
|
|
||||||
frame_step_static, signal.shape[-1].value))
|
|
||||||
|
|
||||||
signal_shape = array_ops.shape(signal)
|
signal_shape = array_ops.shape(signal)
|
||||||
|
|
||||||
# All dimensions that are not part of the overlap-and-add. Can be empty for
|
# All dimensions that are not part of the overlap-and-add. Can be empty for
|
||||||
# rank 2 inputs.
|
# rank 2 inputs.
|
||||||
outer_dimensions = signal_shape[:-2]
|
outer_dimensions = signal_shape[:-2]
|
||||||
|
|
||||||
|
# If frame_length and frame_step are known at graph construction time, check
|
||||||
|
# frame_step is less than or equal to frame_length.
|
||||||
|
frame_step_static = tensor_util.constant_value(frame_step)
|
||||||
|
if (frame_step_static is not None and signal.shape.ndims is not None and
|
||||||
|
signal.shape[-1].value is not None):
|
||||||
|
if frame_step_static > signal.shape[-1].value:
|
||||||
|
raise ValueError(
|
||||||
|
"frame_step (%d) must be less than or equal to "
|
||||||
|
"frame_length (%d)" % (
|
||||||
|
frame_step_static, signal.shape[-1].value))
|
||||||
|
# If frame_length is equal to frame_step, there's no overlap so just
|
||||||
|
# reshape the tensor.
|
||||||
|
if frame_step_static == signal.shape[-1].value:
|
||||||
|
return array_ops.reshape(signal, array_ops.concat(
|
||||||
|
[outer_dimensions, [-1]], 0))
|
||||||
|
|
||||||
signal_rank = array_ops.rank(signal)
|
signal_rank = array_ops.rank(signal)
|
||||||
frames = signal_shape[-2]
|
frames = signal_shape[-2]
|
||||||
frame_length = signal_shape[-1]
|
frame_length = signal_shape[-1]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user