Defer conversion of frame_step to tensor in spectral_ops inverse_stft_window_fn.

PiperOrigin-RevId: 346181742
Change-Id: Ieab49867b1cada05cdb7244f1f9f2482dd06e718
This commit is contained in:
A. Unique TensorFlower 2020-12-07 14:33:16 -08:00 committed by TensorFlower Gardener
parent 19fa254274
commit 0a15fbc048

View File

@ -120,10 +120,6 @@ def inverse_stft_window_fn(frame_step,
The returned window is suitable for reconstructing original waveform in
inverse_stft.
"""
with ops.name_scope(name, 'inverse_stft_window_fn', [forward_window_fn]):
frame_step = ops.convert_to_tensor(frame_step, name='frame_step')
frame_step.shape.assert_has_rank(0)
def inverse_stft_window_fn_inner(frame_length, dtype):
"""Computes a window that can be used in `inverse_stft`.
@ -141,18 +137,20 @@ def inverse_stft_window_fn(frame_step,
`frame_step` is not scalar, or `frame_step` is not scalar.
"""
with ops.name_scope(name, 'inverse_stft_window_fn', [forward_window_fn]):
frame_step_ = ops.convert_to_tensor(frame_step, name='frame_step')
frame_step_.shape.assert_has_rank(0)
frame_length = ops.convert_to_tensor(frame_length, name='frame_length')
frame_length.shape.assert_has_rank(0)
# Use equation 7 from Griffin + Lim.
forward_window = forward_window_fn(frame_length, dtype=dtype)
denom = math_ops.square(forward_window)
overlaps = -(-frame_length // frame_step) # Ceiling division.
denom = array_ops.pad(denom, [(0, overlaps * frame_step - frame_length)])
denom = array_ops.reshape(denom, [overlaps, frame_step])
overlaps = -(-frame_length // frame_step_) # Ceiling division.
denom = array_ops.pad(denom, [(0, overlaps * frame_step_ - frame_length)])
denom = array_ops.reshape(denom, [overlaps, frame_step_])
denom = math_ops.reduce_sum(denom, 0, keepdims=True)
denom = array_ops.tile(denom, [overlaps, 1])
denom = array_ops.reshape(denom, [overlaps * frame_step])
denom = array_ops.reshape(denom, [overlaps * frame_step_])
return forward_window / denom[:frame_length]
return inverse_stft_window_fn_inner