Quick return with simpler reshape op if the frame hop and frame length are statically known and their values are equal (i.e. no overlap). #tf-signal

PiperOrigin-RevId: 208002136
This commit is contained in:
A. Unique TensorFlower 2018-08-08 23:50:57 -07:00 committed by TensorFlower Gardener
parent 4ff625bc58
commit f2494ede82

View File

@ -90,22 +90,28 @@ def overlap_and_add(signal, frame_step, name=None):
raise ValueError("frame_step must be an integer. Got %s" %
frame_step.dtype)
# If frame_length and frame_step are known at graph construction time, check
# frame_step is less than or equal to frame_length.
frame_step_static = tensor_util.constant_value(frame_step)
if (frame_step_static is not None and signal.shape.ndims is not None and
signal.shape[-1].value is not None and
frame_step_static > signal.shape[-1].value):
raise ValueError(
"frame_step (%d) must be less than or equal to frame_length (%d)" % (
frame_step_static, signal.shape[-1].value))
signal_shape = array_ops.shape(signal)
# All dimensions that are not part of the overlap-and-add. Can be empty for
# rank 2 inputs.
outer_dimensions = signal_shape[:-2]
# If frame_length and frame_step are known at graph construction time, check
# frame_step is less than or equal to frame_length.
frame_step_static = tensor_util.constant_value(frame_step)
if (frame_step_static is not None and signal.shape.ndims is not None and
signal.shape[-1].value is not None):
if frame_step_static > signal.shape[-1].value:
raise ValueError(
"frame_step (%d) must be less than or equal to "
"frame_length (%d)" % (
frame_step_static, signal.shape[-1].value))
# If frame_length is equal to frame_step, there's no overlap so just
# reshape the tensor.
if frame_step_static == signal.shape[-1].value:
return array_ops.reshape(signal, array_ops.concat(
[outer_dimensions, [-1]], 0))
signal_rank = array_ops.rank(signal)
frames = signal_shape[-2]
frame_length = signal_shape[-1]