Defer conversion of frame_step to tensor in spectral_ops inverse_stft_window_fn.
PiperOrigin-RevId: 346181742 Change-Id: Ieab49867b1cada05cdb7244f1f9f2482dd06e718
This commit is contained in:
parent
19fa254274
commit
0a15fbc048
@ -120,10 +120,6 @@ def inverse_stft_window_fn(frame_step,
|
|||||||
The returned window is suitable for reconstructing original waveform in
|
The returned window is suitable for reconstructing original waveform in
|
||||||
inverse_stft.
|
inverse_stft.
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(name, 'inverse_stft_window_fn', [forward_window_fn]):
|
|
||||||
frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
|
|
||||||
frame_step.shape.assert_has_rank(0)
|
|
||||||
|
|
||||||
def inverse_stft_window_fn_inner(frame_length, dtype):
|
def inverse_stft_window_fn_inner(frame_length, dtype):
|
||||||
"""Computes a window that can be used in `inverse_stft`.
|
"""Computes a window that can be used in `inverse_stft`.
|
||||||
|
|
||||||
@ -141,18 +137,20 @@ def inverse_stft_window_fn(frame_step,
|
|||||||
`frame_step` is not scalar, or `frame_step` is not scalar.
|
`frame_step` is not scalar, or `frame_step` is not scalar.
|
||||||
"""
|
"""
|
||||||
with ops.name_scope(name, 'inverse_stft_window_fn', [forward_window_fn]):
|
with ops.name_scope(name, 'inverse_stft_window_fn', [forward_window_fn]):
|
||||||
|
frame_step_ = ops.convert_to_tensor(frame_step, name='frame_step')
|
||||||
|
frame_step_.shape.assert_has_rank(0)
|
||||||
frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
|
frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
|
||||||
frame_length.shape.assert_has_rank(0)
|
frame_length.shape.assert_has_rank(0)
|
||||||
|
|
||||||
# Use equation 7 from Griffin + Lim.
|
# Use equation 7 from Griffin + Lim.
|
||||||
forward_window = forward_window_fn(frame_length, dtype=dtype)
|
forward_window = forward_window_fn(frame_length, dtype=dtype)
|
||||||
denom = math_ops.square(forward_window)
|
denom = math_ops.square(forward_window)
|
||||||
overlaps = -(-frame_length // frame_step) # Ceiling division.
|
overlaps = -(-frame_length // frame_step_) # Ceiling division.
|
||||||
denom = array_ops.pad(denom, [(0, overlaps * frame_step - frame_length)])
|
denom = array_ops.pad(denom, [(0, overlaps * frame_step_ - frame_length)])
|
||||||
denom = array_ops.reshape(denom, [overlaps, frame_step])
|
denom = array_ops.reshape(denom, [overlaps, frame_step_])
|
||||||
denom = math_ops.reduce_sum(denom, 0, keepdims=True)
|
denom = math_ops.reduce_sum(denom, 0, keepdims=True)
|
||||||
denom = array_ops.tile(denom, [overlaps, 1])
|
denom = array_ops.tile(denom, [overlaps, 1])
|
||||||
denom = array_ops.reshape(denom, [overlaps * frame_step])
|
denom = array_ops.reshape(denom, [overlaps * frame_step_])
|
||||||
|
|
||||||
return forward_window / denom[:frame_length]
|
return forward_window / denom[:frame_length]
|
||||||
return inverse_stft_window_fn_inner
|
return inverse_stft_window_fn_inner
|
||||||
|
Loading…
x
Reference in New Issue
Block a user