time_frequency

Time-frequency Keras layers.

This module provides low-level implementations of popular time-frequency operations such as STFT and inverse STFT. These layers are the building blocks of kapre.composed, which provides higher-level layers such as a melspectrogram layer. You should go check it out!

Note

Why time-frequency representation?

Every representation (STFT, melspectrogram, etc.) has something in common: they're all 2D representations (time, frequency-ish) of audio signals. They're helpful because they decompose an audio signal, which is a simultaneous mixture of many frequency components, into different frequency bins. They also have a spatial property: the frequency bins are sorted, so nearby bins represent only slightly different frequency components. A similar frequency decomposition happens in human auditory perception, in the cochlea.

Which representation to use as input?

For a quick summary, check out my tutorial paper, A Tutorial on Deep Learning for Music Information Retrieval.

class kapre.time_frequency.STFT(n_fft=2048, win_length=None, hop_length=None, window_name=None, pad_begin=False, pad_end=False, input_data_format='default', output_data_format='default', **kwargs)

A Short-time Fourier transform layer.

It uses tf.signal.stft to compute complex STFT. Additionally, it reshapes the output to be a proper 2D batch.

If output_data_format == ‘channels_last’, the output shape is (batch, time, freq, channel). If output_data_format == ‘channels_first’, the output shape is (batch, channel, time, freq).

Parameters:
  • n_fft (int) – FFT size, i.e., the number of points in each FFT. Defaults to 2048.
  • win_length (int or None) – Window length in samples. Defaults to n_fft.
  • hop_length (int or None) – Hop length in samples between analysis windows. Defaults to win_length // 4, following Librosa.
  • window_name (str or None) – Name of the tf.signal function that returns the 1D window tensor used in analysis. Defaults to hann_window, which uses tf.signal.hann_window. Window availability depends on the TensorFlow version. More details are at kapre.backend.get_window().
  • pad_begin (bool) – Whether to pad zeros at the beginning of the time axis (length: win_length - hop_length). Defaults to False.
  • pad_end (bool) – Whether to pad zeros at the end of the signal. Defaults to False.
  • input_data_format (str) – The audio data format of the input waveform batch: ‘channels_last’ if it’s (batch, time, channels) and ‘channels_first’ if it’s (batch, channels, time). Defaults to the setting of your Keras configuration (tf.keras.backend.image_data_format()).
  • output_data_format (str) – The data format of the output STFT: ‘channels_last’ if you want (batch, time, frequency, channels) and ‘channels_first’ if you want (batch, channels, time, frequency). Defaults to the setting of your Keras configuration (tf.keras.backend.image_data_format()).
  • **kwargs – Keyword args for the parent keras layer (e.g., name)

Example

import kapre
from tensorflow.keras.models import Sequential

input_shape = (2048, 1)  # mono signal
model = Sequential()
model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape))
# now the shape is (batch, n_frame=3, n_freq=513, ch=1)
# and the dtype is complex

call(x)

Compute the STFT of the input signal. tf.signal.stft operates on the last axis, so when time is not the last axis of x (input_data_format == ‘channels_last’), x is transposed internally before the STFT is computed.

Parameters: x (float Tensor) – batch of audio signals, (batch, ch, time) or (batch, time, ch) depending on input_data_format
Returns: An STFT representation of x in a 2D batch shape; complex64 if x is float32, complex128 if x is float64. Its shape is (batch, time, freq, ch) or (batch, ch, time, freq) depending on output_data_format. time is the number of frames: 1 + (len_src - win_length) // hop_length if pad_end is False, or ceil(len_src / hop_length) if pad_end is True, where len_src is the signal length including any pad_begin padding. freq is the number of unique FFT bins, n_fft // 2 + 1.
Return type: (complex Tensor)
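The frame count above can be checked with a small pure-Python helper. This is a sketch, not part of kapre: stft_output_shape is a hypothetical name, and it assumes the layer frames the signal the way tf.signal.stft does (frames of win_length samples taken every hop_length samples).

```python
import math


def stft_output_shape(len_src, n_fft=2048, win_length=None, hop_length=None,
                      pad_begin=False, pad_end=False):
    """Return (n_frames, n_freq) for one channel (hypothetical helper, not kapre API)."""
    win_length = n_fft if win_length is None else win_length
    hop_length = win_length // 4 if hop_length is None else hop_length
    if pad_begin:
        # pad_begin prepends (win_length - hop_length) zeros along the time axis.
        len_src += win_length - hop_length
    if pad_end:
        # The tail is zero-padded, so every hop position starts a frame.
        n_frames = math.ceil(len_src / hop_length)
    else:
        # Only frames that fit entirely inside the signal are kept.
        n_frames = max(0, 1 + (len_src - win_length) // hop_length)
    n_freq = n_fft // 2 + 1  # unique bins of a real-input FFT
    return n_frames, n_freq
```

With the values from the example above, stft_output_shape(2048, n_fft=1024, hop_length=512) returns (3, 513); setting pad_end=True or pad_begin=True adds one frame in that case.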
class kapre.time_frequency.InverseSTFT(n_fft=2048, win_length=None, hop_length=None, forward_window_name=None, input_data_format='default', output_data_format='default', **kwargs)

An inverse-STFT layer.

If output_data_format == ‘channels_last’, the output shape is (batch, time, channel). If output_data_format == ‘channels_first’, the output shape is (batch, channel, time).

Note that the result of the inverse STFT can be longer than the original signal due to padding. Check the length of the result yourself and trim it if needed.
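Assuming the layer reconstructs the signal by overlap-adding frames of win_length samples every hop_length samples (as tf.signal.inverse_stft does), the output length is (n_frames - 1) * hop_length + win_length; for 3 frames with win_length=1024 and hop_length=512 that is 2048 samples. The helpers istft_length and trim_to_length below are hypothetical names used only for illustration; they sketch how to compute that length and trim a NumPy batch back to a known original length.

```python
import numpy as np


def istft_length(n_frames, win_length, hop_length):
    """Samples produced by overlap-adding n_frames windows (hypothetical helper)."""
    return (n_frames - 1) * hop_length + win_length


def trim_to_length(audio, length, data_format="channels_last"):
    """Trim a (batch, time, ch) or (batch, ch, time) array to `length` samples."""
    if data_format == "channels_last":
        return audio[:, :length, :]
    return audio[:, :, :length]


# Dummy stand-in for an inverse-STFT output of 3 frames, then trimmed.
audio = np.zeros((2, istft_length(3, 1024, 512), 1))
trimmed = trim_to_length(audio, 2000)
```

Trimming by slicing keeps the leading samples; if pad_begin=True was used in the forward STFT, the extra samples sit at the start instead, so the slice offset would need adjusting.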

Parameters:
  • n_fft (int) – FFT size, i.e., the number of points in each FFT. Defaults to 2048.
  • win_length (int or None) – Window length in samples. Defaults to n_fft.
  • hop_length (int or None) – Hop length in samples between analysis windows. Defaults to n_fft // 4, following Librosa.
  • forward_window_name (str or None) – Name of the tf.signal function that was used in the forward STFT. Defaults to hann_window, assuming tf.signal.hann_window was used. Window availability depends on the TensorFlow version. More details are at kapre.backend.get_window().
  • input_data_format (str) – The data format of the input STFT batch: ‘channels_last’ if it’s (batch, time, frequency, channels) and ‘channels_first’ if it’s (batch, channels, time, frequency). Defaults to the setting of your Keras configuration (tf.keras.backend.image_data_format()).
  • output_data_format (str) – The audio data format of the output waveform batch: ‘channels_last’ if you want (batch, time, channels) and ‘channels_first’ if you want (batch, channels, time). Defaults to the setting of your Keras configuration (tf.keras.backend.image_data_format()).
  • **kwargs – Keyword args for the parent keras layer (e.g., name)

Example

import kapre
from tensorflow.keras.models import Sequential

input_shape = (3, 513, 1)  # 3 frames, 513 frequency bins, 1 channel
# and input dtype is complex
model = Sequential()
model.add(kapre.InverseSTFT(n_fft=1024, hop_length=512, input_shape=input_shape))
# now the shape is (batch, time=2048, ch=1)

call(x)

Compute the inverse STFT of the input STFT.

Parameters: x (complex Tensor) – batch of STFTs, (batch, ch, time, freq) or (batch, time, freq, ch) depending on input_data_format
Returns: audio signals of x in a 1D batch shape, i.e., (batch, time, ch) or (batch, ch, time) depending on output_data_format
Return type: (float Tensor)
class kapre.time_frequency.Magnitude(*args, **kwargs)

Compute the magnitude of the complex input, resulting in a float tensor.

Example

import kapre
from kapre.time_frequency import Magnitude
from tensorflow.keras.models import Sequential

input_shape = (2048, 1)  # mono signal
model = Sequential()
model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape))
model.add(Magnitude())
# now the shape is (batch, n_frame=3, n_freq=513, ch=1) and dtype is float

call(x)
Parameters: x (complex Tensor) – input complex tensor
Returns: magnitude of x
Return type: (float Tensor)
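In NumPy terms, the magnitude is the element-wise absolute value of the complex tensor, sqrt(real**2 + imag**2), with the shape unchanged. This sketch assumes the layer behaves like such an element-wise absolute value; it uses a toy array rather than a real STFT.

```python
import numpy as np

# A toy complex "STFT" batch in channels_last layout: (batch=1, time=2, freq=2, ch=1).
stft = np.array([[[[3 + 4j], [1 + 0j]],
                  [[0 - 2j], [-6 + 8j]]]], dtype=np.complex64)

# Element-wise sqrt(real**2 + imag**2); complex64 in, float32 out, same shape.
magnitude = np.abs(stft)
```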
class kapre.time_frequency.Phase(approx_atan_accuracy=None, **kwargs)

Compute the phase of the complex input in radians, resulting in a float tensor.

Optionally, an approximate phase algorithm can be used; it returns the same results as the PhaseTflite layer (the TFLite-compatible layer).

Parameters: approx_atan_accuracy (int or None) – If None, tf.math.angle() is used to compute the phase accurately. If an int, it is the number of iterations used to compute an approximate atan() with a TFLite-compatible method; the higher the number, the more accurate the result (e.g., approx_atan_accuracy=29000). You may want to experiment with this number to trade off accuracy against inference speed.

Example

import kapre
from kapre.time_frequency import Phase
from tensorflow.keras.models import Sequential

input_shape = (2048, 1)  # mono signal
model = Sequential()
model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape))
model.add(Phase())
# now the shape is (batch, n_frame=3, n_freq=513, ch=1) and dtype is float

call(x)
Parameters: x (complex Tensor) – input complex tensor
Returns: phase of x (radians)
Return type: (float Tensor)
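The exact phase (the approx_atan_accuracy=None case) corresponds to the two-argument arctangent of the imaginary and real parts, which resolves the quadrant. The NumPy sketch below uses np.angle as a stand-in for tf.math.angle and checks it against np.arctan2; it does not reproduce the approximate TFLite-compatible algorithm.

```python
import numpy as np

# One point on each axis plus a diagonal, to exercise every quadrant boundary.
stft = np.array([1 + 0j, 0 + 1j, -1 + 0j, 0 - 1j, 1 + 1j], dtype=np.complex64)

# Exact phase in radians, in the range (-pi, pi].
phase = np.angle(stft)

# The same value from the two-argument arctangent of (imag, real).
phase_atan2 = np.arctan2(stft.imag, stft.real)
```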
class kapre.time_frequency.MagnitudeToDecibel(ref_value=1.0, amin=1e-05, dynamic_range=80.0, **kwargs)

A class that wraps backend.magnitude_to_decibel to compute the decibel scale of the input magnitude.

Parameters:
  • ref_value (float) – The input value that becomes 0 dB in the result. For spectrogram magnitudes, ref_value=1.0 usually makes the decibel-scaled output be around zero if the input audio was in [-1, 1].
  • amin (float) – The noise floor of the input. Any input smaller than amin is converted to amin.
  • dynamic_range (float) – The range of the resulting values in dB. E.g., if the maximum magnitude is 30 dB, the noise floor of the output becomes (30 - dynamic_range) dB.
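The NumPy sketch below shows how ref_value, amin, and dynamic_range interact. It is an illustrative reimplementation, not kapre's code: it assumes 10 * log10 scaling and that the maximum used for dynamic_range is taken per batch item over all non-batch axes; check backend.magnitude_to_decibel for the exact convention.

```python
import numpy as np


def magnitude_to_decibel(x, ref_value=1.0, amin=1e-5, dynamic_range=80.0):
    """Sketch of magnitude-to-dB conversion (assumes 10*log10 scaling)."""
    x = np.asarray(x, dtype=np.float64)
    # Clamp to the noise floor `amin` before taking the log.
    log_spec = 10.0 * np.log10(np.maximum(x, amin))
    # Shift so that `ref_value` maps to 0 dB.
    log_spec -= 10.0 * np.log10(max(amin, ref_value))
    # Treat axis 0 as the batch axis when x has more than one dimension.
    axes = tuple(range(1, x.ndim)) if x.ndim > 1 else None
    peak = np.max(log_spec, axis=axes, keepdims=True)
    # Keep at most `dynamic_range` dB below each item's peak.
    return np.maximum(log_spec, peak - dynamic_range)
```

For example, [1.0, 1e-10, 10.0] maps to [0, -50, 10] dB with the defaults: 1e-10 is first raised to amin (-50 dB), and with dynamic_range=30.0 the floor becomes 10 - 30 = -20 dB.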

Example

import kapre
from kapre.time_frequency import Magnitude, MagnitudeToDecibel
from tensorflow.keras.models import Sequential

input_shape = (2048, 1)  # mono signal
model = Sequential()
model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape))
model.add(Magnitude())
model.add(MagnitudeToDecibel())
# now the shape is (batch, n_frame=3, n_freq=513, ch=1) and dtype is float

call(x)
Parameters: x (Tensor) – Float tensor, batched or not; typically the magnitude of an STFT.
Returns: decibel-scaled float tensor of x.
Return type: (Tensor)
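class kapre.time_frequency.ApplyFilterbank(type, filterbank_kwargs, data_format='default', **kwargs)

Apply a filterbank to the input spectrograms.

Parameters:
  • type (str) – Type of the filterbank to build, e.g., ‘mel’ as in the example below.
  • filterbank_kwargs (dict) – Keyword arguments used to build the filterbank (e.g., sample_rate, n_freq, n_mels, f_min, and f_max for ‘mel’). The resulting filterbank tensor has a shape of (n_freq, n_filterbanks).
  • data_format (str) – Specifies the data format of batch input/output.
  • **kwargs – Keyword args for the parent keras layer (e.g., name)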

Example

import kapre
from kapre.time_frequency import ApplyFilterbank, Magnitude
from tensorflow.keras.models import Sequential

input_shape = (2048, 1)  # mono signal
n_fft = 1024
n_hop = n_fft // 2
kwargs = {
    'sample_rate': 22050,
    'n_freq': n_fft // 2 + 1,
    'n_mels': 128,
    'f_min': 0.0,
    'f_max': 8000,
}
model = Sequential()
model.add(kapre.STFT(n_fft=n_fft, hop_length=n_hop, input_shape=input_shape))
model.add(Magnitude())
# (batch, n_frame=3, n_freq=n_fft // 2 + 1, ch=1) and dtype is float
model.add(ApplyFilterbank(type='mel', filterbank_kwargs=kwargs))
# (batch, n_frame=3, n_mels=128, ch=1)

call(x)

Apply the filterbank to x.

Parameters: x (Tensor) – float tensor in a 2D batch shape.
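Applying a filterbank of shape (n_freq, n_filterbanks) to a channels_last spectrogram batch amounts to a matrix product over the frequency axis. The NumPy sketch below assumes that behavior and uses a random matrix as a placeholder for a real mel filterbank, so only the shapes and the contraction are meaningful.

```python
import numpy as np

batch, n_frame, n_freq, n_ch, n_mels = 2, 3, 513, 1, 128

rng = np.random.default_rng(0)
spec = rng.random((batch, n_frame, n_freq, n_ch))  # channels_last magnitudes
filterbank = rng.random((n_freq, n_mels))  # placeholder, not a real mel filterbank

# Contract the frequency axis: (b, t, f, c) x (f, m) -> (b, t, m, c).
mel_spec = np.einsum("btfc,fm->btmc", spec, filterbank)

# Same result via tensordot, which yields (b, t, c, m); move m back before c.
alt = np.moveaxis(np.tensordot(spec, filterbank, axes=([2], [0])), -1, 2)
```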
class kapre.time_frequency.Delta(win_length=5, mode='symmetric', data_format='default', **kwargs)

Calculates delta, a local estimate of the derivative along the time axis. See torchaudio.functional.compute_deltas or librosa.feature.delta for more details.

Parameters:
  • win_length (int) – Window length of the derivative estimation. Defaults to 5.
  • mode (str) – Pad mode passed to tf.pad. Case-insensitive. Defaults to ‘symmetric’. Can be ‘symmetric’, ‘reflect’, or ‘constant’ (the modes tf.pad supports).
  • data_format (str) – Specifies the data format of batch input/output.

Example

import kapre
from kapre.time_frequency import Delta
from tensorflow.keras.models import Sequential

input_shape = (2048, 1)  # mono signal
model = Sequential()
model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape))
model.add(kapre.Magnitude())
model.add(Delta())
# (batch, n_frame=3, n_freq=513, ch=1) and dtype is float

call(x)
Parameters: x (Tensor) – a 2D batch (b, t, f, ch) or (b, ch, t, f)
Returns: A tensor with the same shape as the input data.
Return type: (Tensor)
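The delta is commonly computed with the regression formula used by librosa and torchaudio: with N = (win_length - 1) // 2, delta[t] = sum_{n=1..N} n * (x[t+n] - x[t-n]) / (2 * sum_{n=1..N} n**2), after padding the time axis by N frames on each side. The NumPy sketch below implements that formula for a channels_last batch; it is an illustration rather than kapre's code, and it pads with np.pad, whose ‘symmetric’, ‘reflect’, and ‘constant’ modes share their names with tf.pad's.

```python
import numpy as np


def delta(x, win_length=5, mode="symmetric"):
    """Regression-based delta along time for x of shape (batch, time, freq, ch)."""
    if win_length < 3 or win_length % 2 == 0:
        raise ValueError("this sketch requires an odd win_length >= 3")
    n = (win_length - 1) // 2
    denom = 2 * sum(i * i for i in range(1, n + 1))
    # Pad only the time axis so every frame has n neighbours on each side.
    padded = np.pad(x, ((0, 0), (n, n), (0, 0), (0, 0)), mode=mode)
    t = x.shape[1]
    out = np.zeros_like(x, dtype=np.float64)
    for i in range(1, n + 1):
        # Frame j of x sits at index j + n of `padded`.
        out += i * (padded[:, n + i : n + i + t] - padded[:, n - i : n - i + t])
    return out / denom
```

On a linear ramp along time, the interior frames get a delta of exactly 1.0, while the first and last frames are pulled toward zero by the padding.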
class kapre.time_frequency.ConcatenateFrequencyMap(data_format='default', **kwargs)

Adds a frequency information channel to spectrograms.

The added frequency channel (= frequency map) has values that increase linearly from 0.0 to 1.0, indicating the normalized frequency of a time-frequency bin. This layer can be applied to input audio spectrograms or any feature maps so that the following layers can be conditioned on the frequency. (Imagine something like positional encoding in NLP, but with the position on the frequency axis.)

A combination of ConcatenateFrequencyMap and Conv2D is known as frequency-aware convolution (see References). For your convenience, such a layer is provided by kapre.composed.get_frequency_aware_conv2d().
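A minimal NumPy sketch of the idea for a channels_last batch, assuming the frequency map is np.linspace(0.0, 1.0, n_freq) broadcast over batch and time and appended as the last channel. concat_frequency_map is a hypothetical helper, and kapre's implementation may differ in details.

```python
import numpy as np


def concat_frequency_map(x):
    """x: (batch, time, freq, ch), channels_last. Returns (batch, time, freq, ch + 1)."""
    batch, n_time, n_freq, _ = x.shape
    # One value per frequency bin, rising linearly from 0.0 (lowest) to 1.0 (highest).
    freq_map = np.linspace(0.0, 1.0, n_freq, dtype=x.dtype)
    # Broadcast to (batch, time, freq, 1) so it can be stacked as an extra channel.
    freq_map = np.broadcast_to(freq_map[None, None, :, None], (batch, n_time, n_freq, 1))
    return np.concatenate([x, freq_map], axis=-1)
```

Because the map depends only on the bin index, a following convolution sees the same extra channel at every time step, which is what lets it condition on frequency.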

Parameters:
  • data_format (str) – specifies the data format of batch input/output.
  • **kwargs – Keyword args for the parent keras layer (e.g., name)

Example

import kapre
from tensorflow import keras
from tensorflow.keras.models import Sequential

input_shape = (2048, 1)  # mono signal
model = Sequential()
model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape))
model.add(kapre.Magnitude())
# (batch, n_frame=3, n_freq=513, ch=1) and dtype is float
model.add(kapre.ConcatenateFrequencyMap())
# (batch, n_frame=3, n_freq=513, ch=2)
# now add your model; padding='same' keeps the short (3-frame) time axis from collapsing
model.add(keras.layers.Conv2D(16, (3, 3), strides=(2, 2), padding='same', activation='relu'))
# you can concatenate a frequency map before other conv layers,
# but you probably wouldn't want to add it right before batch normalization.
model.add(kapre.ConcatenateFrequencyMap())
model.add(keras.layers.Conv2D(32, (3, 3), strides=(1, 1), padding='same', activation='relu'))
model.add(keras.layers.MaxPooling2D((2, 2)))  # length of frequency axis doesn't matter

References

Koutini, K., Eghbal-zadeh, H., & Widmer, G. (2019). Receptive-Field-Regularized CNN Variants for Acoustic Scene Classification. In Proceedings of the Detection and Classification of Acoustic Scenes and Events 2019 Workshop (DCASE2019).

call(x)
Parameters: x (Tensor) – a 2D batch (b, t, f, ch) or (b, ch, t, f)
Returns: a 2D batch (b, t, f, ch + 1) or (b, ch + 1, t, f)
Return type: (Tensor)