Source code for kapre.time_frequency

"""Time-frequency Keras layers.

This module has low-level implementations of some popular time-frequency operations such as STFT and inverse STFT.
We're using these layers to compose layers in `kapre.composed` where more high-level and popular layers
such as melspectrogram layer are provided. You should go check it out!

Note:
    **Why time-frequency representation?**

    Every representation (STFT, melspectrogram, etc.) has something in common - they're all 2D representations
    (time, frequency-ish) of audio signals. They're helpful because they decompose an audio signal, which is a simultaneous
    mixture of a lot of frequency components, into different frequency bins. They have a spatial property; the frequency
    bins are *sorted*, so nearby frequency bins represent only slightly different frequency components. The
    frequency decomposition is also what happens during human auditory perception in the cochlea.

    **Which representation to use as input?**

    For a quick summary, check out my tutorial paper, `A Tutorial on Deep Learning for Music Information Retrieval <https://arxiv.org/abs/1709.04396>`_.

"""
import tensorflow as tf
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer

from . import backend
from .backend import _CH_FIRST_STR, _CH_LAST_STR, _CH_DEFAULT_STR
from .tflite_compatible_stft import atan2_tflite

__all__ = [
    'STFT',
    'InverseSTFT',
    'Magnitude',
    'Phase',
    'MagnitudeToDecibel',
    'ApplyFilterbank',
    'Delta',
    'ConcatenateFrequencyMap',
]


def _shape_spectrum_output(spectrums, data_format):
    """Rearrange a batch of spectrograms into the requested data format.

    Args:
        spectrums (`Tensor`): result of tf.signal.stft or similar, i.e., (..., time, freq).
            Assumed to be 4D (batch, channel, time, freq), since a 4-axis permutation is applied.
        data_format (`str`): 'channels_first' or 'channels_last'

    Returns:
        spectrums (`Tensor`): the input unchanged for 'channels_first'; for any other value,
        the input transposed with perm (0, 2, 3, 1), i.e. (batch, time, freq, channel).

    """
    if data_format != _CH_FIRST_STR:
        # (batch, channel, time, freq) -> (batch, time, freq, channel)
        return tf.transpose(spectrums, perm=(0, 2, 3, 1))
    # Presumably already (batch, channel, time, freq); nothing to do.
    return spectrums


class STFT(Layer):
    """
    A Short-time Fourier transform layer.

    It uses `tf.signal.stft` to compute complex STFT. Additionally, it reshapes the output to be a proper 2D batch.

    If `output_data_format == 'channels_last'`, the output shape is (batch, time, freq, channel)
    If `output_data_format == 'channels_first'`, the output shape is (batch, channel, time, freq)

    Args:
        n_fft (int): Number of FFTs. Defaults to `2048`
        win_length (int or None): Window length in sample. Defaults to `n_fft`.
        hop_length (int or None): Hop length in sample between analysis windows. Defaults to
            `win_length // 4` following Librosa.
        window_name (str or None): *Name* of `tf.signal` function that returns a 1D tensor window that is
            used in analysis. Defaults to `hann_window` which uses `tf.signal.hann_window`.
            Window availability depends on Tensorflow version. More details are at `kapre.backend.get_window()`.
        pad_begin (bool): Whether to prepend zeros along the time axis before framing.
            The padding length is `n_fft - hop_length` samples. Defaults to `False`.
        pad_end (bool): Whether to pad with zeros at the finishing end of the signal
            (forwarded to `tf.signal.stft` as `pad_end`). Defaults to `False`.
        input_data_format (str): the audio data format of input waveform batch.
            `'channels_last'` if it's `(batch, time, channels)` and
            `'channels_first'` if it's `(batch, channels, time)`.
            Defaults to the setting of your Keras configuration. (`tf.keras.backend.image_data_format()`)
        output_data_format (str): The data format of output STFT.
            `'channels_last'` if you want `(batch, time, frequency, channels)` and
            `'channels_first'` if you want `(batch, channels, time, frequency)`
            Defaults to the setting of your Keras configuration. (`tf.keras.backend.image_data_format()`)
        **kwargs: Keyword args for the parent keras layer (e.g., `name`)

    Example:
        ::

            input_shape = (2048, 1)  # mono signal
            model = Sequential()
            model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape))
            # now the shape is (batch, n_frame=3, n_freq=513, ch=1)
            # and the dtype is complex
    """

    def __init__(
        self,
        n_fft=2048,
        win_length=None,
        hop_length=None,
        window_name=None,
        pad_begin=False,
        pad_end=False,
        input_data_format='default',
        output_data_format='default',
        **kwargs,
    ):
        super(STFT, self).__init__(**kwargs)

        backend.validate_data_format_str(input_data_format)
        backend.validate_data_format_str(output_data_format)

        if win_length is None:
            win_length = n_fft
        if hop_length is None:
            hop_length = win_length // 4

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.window_name = window_name
        self.window_fn = backend.get_window_fn(window_name)
        self.pad_begin = pad_begin
        self.pad_end = pad_end

        # Resolve 'default' to the data format configured in Keras.
        idt, odt = input_data_format, output_data_format
        self.output_data_format = K.image_data_format() if odt == _CH_DEFAULT_STR else odt
        self.input_data_format = K.image_data_format() if idt == _CH_DEFAULT_STR else idt

    def call(self, x):
        """
        Compute STFT of the input signal.

        The input is transposed internally to (batch, ch, time) when `input_data_format` is
        `'channels_last'`, so callers do not need to transpose it themselves.

        Args:
            x (float `Tensor`): batch of audio signals, (batch, ch, time) or (batch, time, ch) based on
                input_data_format

        Return:
            (complex `Tensor`): A STFT representation of x in a 2D batch shape.
            `complex64` if `x` is `float32`, `complex128` if `x` is `float64`.
            Its shape is (batch, time, freq, ch) or (batch, ch, time, freq) depending on `output_data_format`.
            `time` is the number of frames: `1 + (len_src - win_length) // hop_length` if `pad_end` is `False`,
            or `ceil(len_src / hop_length)` if `pad_end` is `True`, where `len_src` includes any `pad_begin`
            padding. `freq` is the number of unique fft bins, which is `n_fft // 2 + 1`.
        """
        waveforms = x  # (batch, ch, time) if input_data_format == 'channels_first'.
        # (batch, time, ch) if input_data_format == 'channels_last'.

        # this is needed because tf.signal.stft lives in channels_first land.
        if self.input_data_format == _CH_LAST_STR:
            waveforms = tf.transpose(waveforms, perm=(0, 2, 1))  # always (batch, ch, time) from here

        if self.pad_begin:
            # Prepend `n_fft - hop_length` zeros on the time (last) axis only.
            waveforms = tf.pad(
                waveforms, tf.constant([[0, 0], [0, 0], [int(self.n_fft - self.hop_length), 0]])
            )

        stfts = tf.signal.stft(
            signals=waveforms,
            frame_length=self.win_length,
            frame_step=self.hop_length,
            fft_length=self.n_fft,
            window_fn=self.window_fn,
            pad_end=self.pad_end,
            name='%s_tf.signal.stft' % self.name,
        )  # (batch, ch, time, freq)

        if self.output_data_format == _CH_LAST_STR:
            stfts = tf.transpose(stfts, perm=(0, 2, 3, 1))  # (batch, t, f, ch)

        return stfts

    def get_config(self):
        """Return the layer config, including every constructor argument needed to rebuild it."""
        config = super(STFT, self).get_config()
        config.update(
            {
                'n_fft': self.n_fft,
                'win_length': self.win_length,
                'hop_length': self.hop_length,
                'window_name': self.window_name,
                'pad_begin': self.pad_begin,
                'pad_end': self.pad_end,
                'input_data_format': self.input_data_format,
                'output_data_format': self.output_data_format,
            }
        )
        return config
class InverseSTFT(Layer):
    """An inverse-STFT layer.

    If `output_data_format == 'channels_last'`, the output shape is (batch, time, channel)
    If `output_data_format == 'channels_first'`, the output shape is (batch, channel, time)

    Note that the result of inverse STFT could be longer than the original signal due to the padding.
    Do check the size of the result by yourself and trim it if needed.

    Args:
        n_fft (int): Number of FFTs. Defaults to `2048`
        win_length (`int` or `None`): Window length in sample. Defaults to `n_fft`.
        hop_length (`int` or `None`): Hop length in sample between analysis windows. Defaults to
            `win_length // 4` following Librosa.
        forward_window_name (str or None): *Name* of `tf.signal` function that *was* used in the forward STFT.
            Defaults to `hann_window`, assuming `tf.signal.hann_window` was used.
            Window availability depends on Tensorflow version. More details are at `kapre.backend.get_window()`.
        input_data_format (`str`): the data format of input STFT batch.
            `'channels_last'` if it's `(batch, time, frequency, channels)`
            `'channels_first'` if it's `(batch, channels, time, frequency)`
            Defaults to the setting of your Keras configuration. (tf.keras.backend.image_data_format())
        output_data_format (`str`): the audio data format of output waveform batch.
            `'channels_last'` if you want `(batch, time, channels)`
            `'channels_first'` if you want `(batch, channels, time)`
            Defaults to the setting of your Keras configuration. (tf.keras.backend.image_data_format())
        **kwargs: Keyword args for the parent keras layer (e.g., `name`)

    Example:
        ::

            input_shape = (3, 513, 1)  # 3 frames, 513 frequency bins, 1 channel
            # and input dtype is complex
            model = Sequential()
            model.add(kapre.InverseSTFT(n_fft=1024, hop_length=512, input_shape=input_shape))
            # now the shape is (batch, time=2048, ch=1)
    """

    def __init__(
        self,
        n_fft=2048,
        win_length=None,
        hop_length=None,
        forward_window_name=None,
        input_data_format='default',
        output_data_format='default',
        **kwargs,
    ):
        super(InverseSTFT, self).__init__(**kwargs)

        backend.validate_data_format_str(input_data_format)
        backend.validate_data_format_str(output_data_format)

        if win_length is None:
            win_length = n_fft
        if hop_length is None:
            hop_length = win_length // 4

        self.n_fft = n_fft
        self.win_length = win_length
        self.hop_length = hop_length
        self.forward_window_name = forward_window_name
        # The synthesis window is derived from the forward (analysis) window and hop length.
        self.window_fn = tf.signal.inverse_stft_window_fn(
            frame_step=hop_length, forward_window_fn=backend.get_window_fn(forward_window_name)
        )

        # Resolve 'default' to the data format configured in Keras.
        idt, odt = input_data_format, output_data_format
        self.output_data_format = K.image_data_format() if odt == _CH_DEFAULT_STR else odt
        self.input_data_format = K.image_data_format() if idt == _CH_DEFAULT_STR else idt

    def call(self, x):
        """
        Compute inverse STFT of the input STFT.

        Args:
            x (complex `Tensor`): batch of STFTs, (batch, ch, time, freq) or (batch, time, freq, ch) depending on
                `input_data_format`

        Return:
            (`float`): audio signals of x. Shape: 1D batch shape. I.e., (batch, time, ch) or (batch, ch, time)
            depending on `output_data_format`
        """
        stfts = x  # (batch, ch, time, freq) if input_data_format == 'channels_first'.
        # (batch, time, freq, ch) if input_data_format == 'channels_last'.

        # this is needed because tf.signal.inverse_stft lives in channels_first land.
        if self.input_data_format == _CH_LAST_STR:
            stfts = tf.transpose(stfts, perm=(0, 3, 1, 2))  # now always (b, ch, t, f)

        waveforms = tf.signal.inverse_stft(
            stfts=stfts,
            frame_length=self.win_length,
            frame_step=self.hop_length,
            fft_length=self.n_fft,
            window_fn=self.window_fn,
            name='%s_tf.signal.istft' % self.name,
        )  # (batch, ch, time)

        if self.output_data_format == _CH_LAST_STR:
            waveforms = tf.transpose(waveforms, perm=(0, 2, 1))  # (batch, time, ch)

        return waveforms

    def get_config(self):
        """Return the layer config, including every constructor argument needed to rebuild it."""
        config = super(InverseSTFT, self).get_config()
        config.update(
            {
                'n_fft': self.n_fft,
                'win_length': self.win_length,
                'hop_length': self.hop_length,
                'forward_window_name': self.forward_window_name,
                'input_data_format': self.input_data_format,
                'output_data_format': self.output_data_format,
            }
        )
        return config
class Magnitude(Layer):
    """Compute the magnitude of the complex input, resulting in a float tensor

    Example:
        ::

            input_shape = (2048, 1)  # mono signal
            model = Sequential()
            model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape))
            model.add(Magnitude())
            # now the shape is (batch, n_frame=3, n_freq=513, ch=1) and dtype is float
    """

    def call(self, x):
        """
        Args:
            x (complex `Tensor`): input complex tensor

        Returns:
            (float `Tensor`): element-wise magnitude of `x` (via `tf.abs`), same shape as `x`
        """
        return tf.abs(x)
class Phase(Layer):
    """Compute the phase of the complex input in radian, resulting in a float tensor

    Includes an option to use an approximate phase algorithm; with it, this layer returns the same results
    as the PhaseTflite layer (the tflite compatible layer).

    Args:
        approx_atan_accuracy (`int`): if `None` (or otherwise falsy), `tf.math.angle()` is used to calculate
            the phase accurately. If an `int`, this is the number of iterations used to calculate the
            approximate atan() with a tflite compatible method. The higher the number, the more accurate,
            e.g. approx_atan_accuracy=29000. You may want to experiment with adjusting this number:
            trading off accuracy with inference speed.
        **kwargs: Keyword args for the parent keras layer (e.g., `name`). The legacy key
            `tflite_phase_accuracy` is also accepted as an alias of `approx_atan_accuracy`.

    Example:
        ::

            input_shape = (2048, 1)  # mono signal
            model = Sequential()
            model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape))
            model.add(Phase())
            # now the shape is (batch, n_frame=3, n_freq=513, ch=1) and dtype is float
    """

    def __init__(self, approx_atan_accuracy=None, **kwargs):
        # get_config used to emit this option under the key 'tflite_phase_accuracy'. Accept it here
        # so configs saved that way can still be restored instead of reaching Layer.__init__ as an
        # unknown keyword argument.
        legacy_accuracy = kwargs.pop('tflite_phase_accuracy', None)
        if approx_atan_accuracy is None:
            approx_atan_accuracy = legacy_accuracy
        super(Phase, self).__init__(**kwargs)
        self.approx_atan_accuracy = approx_atan_accuracy

    def call(self, x):
        """
        Args:
            x (complex `Tensor`): input complex tensor

        Returns:
            (float `Tensor`): phase of `x` (Radian)
        """
        if self.approx_atan_accuracy:
            return atan2_tflite(tf.math.imag(x), tf.math.real(x), n=self.approx_atan_accuracy)
        return tf.math.angle(x)

    def get_config(self):
        """Return the layer config; keys match the constructor arguments so `from_config` works."""
        config = super(Phase, self).get_config()
        config.update(
            {
                'approx_atan_accuracy': self.approx_atan_accuracy,
            }
        )
        return config
class MagnitudeToDecibel(Layer):
    """Convert an input magnitude to decibels by wrapping `backend.magnitude_to_decibel`.

    Args:
        ref_value (`float`): the input value that maps to 0 dB in the result.
            For spectrogram magnitudes, ref_value=1.0 usually makes the decibel-scaled output
            be around zero if the input audio was in [-1, 1].
        amin (`float`): the noise floor of the input. An input smaller than `amin` is converted to `amin`.
        dynamic_range (`float`): range of the resulting value. E.g., if the maximum magnitude is 30 dB,
            the noise floor of the output would become (30 - dynamic_range) dB
        **kwargs: Keyword args for the parent keras layer (e.g., `name`)

    Example:
        ::

            input_shape = (2048, 1)  # mono signal
            model = Sequential()
            model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape))
            model.add(Magnitude())
            model.add(MagnitudeToDecibel())
            # now the shape is (batch, n_frame=3, n_freq=513, ch=1) and dtype is float
    """

    def __init__(self, ref_value=1.0, amin=1e-5, dynamic_range=80.0, **kwargs):
        super(MagnitudeToDecibel, self).__init__(**kwargs)
        self.ref_value, self.amin, self.dynamic_range = ref_value, amin, dynamic_range

    def call(self, x):
        """
        Args:
            x (`Tensor`): float tensor, batched or not; typically the magnitude of an STFT.

        Returns:
            (`Tensor`): `x` scaled to decibels by `backend.magnitude_to_decibel`.
        """
        db_options = dict(ref_value=self.ref_value, amin=self.amin, dynamic_range=self.dynamic_range)
        return backend.magnitude_to_decibel(x, **db_options)

    def get_config(self):
        """Return the parent config extended with this layer's constructor arguments."""
        layer_config = super(MagnitudeToDecibel, self).get_config()
        layer_config['amin'] = self.amin
        layer_config['dynamic_range'] = self.dynamic_range
        layer_config['ref_value'] = self.ref_value
        return layer_config
class ApplyFilterbank(Layer):
    """
    Apply a filterbank to the input spectrograms.

    Args:
        type (`str`): which filterbank to build, `'log'` (via `backend.filterbank_log`) or
            `'mel'` (via `backend.filterbank_mel`). The resulting filterbank is contracted with the
            frequency axis of the input, i.e. it is used as a (n_freq, n_filterbanks) matrix.
        filterbank_kwargs (`dict`): keyword arguments passed to the filterbank builder selected by `type`.
        data_format (`str`): specifies the data format of batch input/output
        **kwargs: Keyword args for the parent keras layer (e.g., `name`)

    Raises:
        ValueError: if `type` is neither `'log'` nor `'mel'`.

    Example:
        ::

            input_shape = (2048, 1)  # mono signal
            n_fft = 1024
            n_hop = n_fft // 2
            kwargs = {
                'sample_rate': 22050,
                'n_freq': n_fft // 2 + 1,
                'n_mels': 128,
                'f_min': 0.0,
                'f_max': 8000,
            }
            model = Sequential()
            model.add(kapre.STFT(n_fft=n_fft, hop_length=n_hop, input_shape=input_shape))
            model.add(Magnitude())
            # (batch, n_frame=3, n_freq=n_fft // 2 + 1, ch=1) and dtype is float
            model.add(ApplyFilterbank(type='mel', filterbank_kwargs=kwargs))
            # (batch, n_frame=3, n_mels=128, ch=1)
    """

    def __init__(
        self,
        type,
        filterbank_kwargs,
        data_format='default',
        **kwargs,
    ):
        backend.validate_data_format_str(data_format)

        # Fail at construction rather than with an AttributeError on `self.filterbank` in `call`.
        if type == 'log':
            filterbank = backend.filterbank_log(**filterbank_kwargs)
        elif type == 'mel':
            filterbank = backend.filterbank_mel(**filterbank_kwargs)
        else:
            raise ValueError("type should be either 'log' or 'mel', but it is {}".format(type))

        super(ApplyFilterbank, self).__init__(**kwargs)

        self.type = type
        self.filterbank_kwargs = filterbank_kwargs
        self.filterbank = filterbank

        if data_format == _CH_DEFAULT_STR:
            self.data_format = K.image_data_format()
        else:
            self.data_format = data_format

        # Frequency axis of a 2D batch: (b, ch, t, fr) -> 3, (b, t, fr, ch) -> 2.
        if self.data_format == _CH_FIRST_STR:
            self.freq_axis = 3
        else:
            self.freq_axis = 2

    def call(self, x):
        """
        Apply filterbank to `x`.

        Args:
            x (`Tensor`): float tensor in 2D batch shape, (b, t, fr, ch) or (b, ch, t, fr).

        Returns:
            (`Tensor`): (b, t, new_fr, ch) or (b, ch, t, new_fr), where `new_fr` is the number of filters.
        """
        # x: 2d batch input. (b, t, fr, ch) or (b, ch, t, fr)
        output = tf.tensordot(x, self.filterbank, axes=(self.freq_axis, 0))
        # tensordot appends the filter axis last:
        # ch_last -> (b, t, ch, new_fr). ch_first -> (b, ch, t, new_fr)
        if self.data_format == _CH_LAST_STR:
            output = tf.transpose(output, (0, 1, 3, 2))
        return output

    def get_config(self):
        """Return the layer config, including every constructor argument needed to rebuild it."""
        config = super(ApplyFilterbank, self).get_config()
        config.update(
            {
                'type': self.type,
                'filterbank_kwargs': self.filterbank_kwargs,
                'data_format': self.data_format,
            }
        )
        return config
class Delta(Layer):
    """Calculates delta, a local estimate of the derivative along time axis.

    See torchaudio.functional.compute_deltas or librosa.feature.delta for more details.

    The delta at frame `t` is `sum_{k=1..n} k * (x[t + k] - x[t - k]) / (2 * sum_{k=1..n} k^2)`
    with `n = (win_length - 1) // 2`, computed independently for every frequency bin and channel.

    Args:
        win_length (int): Window length of the derivative estimation. Must be an odd number >= 3.
            Defaults to 5
        mode (`str`): Specifies pad mode of `tf.pad`. Case-insensitive. Defaults to 'symmetric'.
            Can be 'symmetric', 'reflect', or 'constant'.
        data_format (`str`): specifies the data format of batch input/output.
            Defaults to the setting of your Keras configuration.
        **kwargs: Keyword args for the parent keras layer (e.g., `name`)

    Raises:
        ValueError: if `win_length` is smaller than 3 or even, or `mode` is not supported.

    Example:
        ::

            input_shape = (2048, 1)  # mono signal
            model = Sequential()
            model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape))
            model.add(kapre.Magnitude())
            model.add(Delta())
            # (batch, n_frame=3, n_freq=513, ch=1) and dtype is float
    """

    def __init__(self, win_length=5, mode='symmetric', data_format='default', **kwargs):
        backend.validate_data_format_str(data_format)

        if not win_length >= 3:
            raise ValueError(
                'win_length should be equal or bigger than 3, but it is %d' % win_length
            )
        if win_length % 2 != 1:
            raise ValueError('win_length should be an odd number, but it is %d' % win_length)
        supported_modes = ('symmetric', 'reflect', 'constant')
        if mode.lower() not in supported_modes:
            raise ValueError(
                'mode.lower() should be one of {}'.format(str(supported_modes))
                + ' but it is {}'.format(mode)
            )

        super(Delta, self).__init__(**kwargs)

        if data_format == _CH_DEFAULT_STR:
            self.data_format = K.image_data_format()
        else:
            self.data_format = data_format

        self.win_length = win_length
        self.mode = mode
        self.n = (self.win_length - 1) // 2  # half window length
        self.denom = 2 * sum(_n ** 2 for _n in range(1, self.n + 1))  # denominator

    def call(self, x):
        """
        Args:
            x (`Tensor`): a 2d batch (b, t, f, ch) or (b, ch, t, f), any number of channels.

        Returns:
            (`Tensor`): A tensor with the same shape as input data.
        """
        if self.data_format == _CH_FIRST_STR:
            x = K.permute_dimensions(x, (0, 2, 3, 1))  # -> (b, t, f, ch)

        n = self.n
        x_padded = tf.pad(
            x, tf.constant([[0, 0], [n, n], [0, 0], [0, 0]]), mode=self.mode
        )  # pad over time so the output keeps the input length

        # Weighted differences of time-shifted slices. Unlike a (time, 1, 1, 1) conv2d kernel, this
        # works for any channel count and any float dtype. Slice bounds are Python ints so static
        # shapes are preserved. With padded length T + 2n:
        #   x_padded[:, n + k : -(n - k)] == x[t + k]   (open end when k == n)
        #   x_padded[:, n - k : -(n + k)] == x[t - k]
        weighted_diffs = [
            k * (x_padded[:, n + k : (k - n) or None] - x_padded[:, n - k : -(n + k)])
            for k in range(1, n + 1)
        ]
        x = tf.add_n(weighted_diffs) / self.denom

        if self.data_format == _CH_FIRST_STR:
            x = K.permute_dimensions(x, (0, 3, 1, 2))
        return x

    def get_config(self):
        """Return the layer config, including every constructor argument needed to rebuild it."""
        config = super(Delta, self).get_config()
        config.update(
            {'win_length': self.win_length, 'mode': self.mode, 'data_format': self.data_format}
        )
        return config
class ConcatenateFrequencyMap(Layer):
    """Adds a frequency information channel to spectrograms.

    The added frequency channel (=frequency map) has linearly increasing values from 0.0 to 1.0,
    indicating the normalized frequency of a time-frequency bin. This layer can be applied to input audio spectrograms
    or any feature maps so that the following layers can be conditioned on the frequency. (Imagine something like
    positional encoding in NLP but the position is on frequency axis).

    A combination of `ConcatenateFrequencyMap` and `Conv2D` is known as frequency-aware convolution (see References).
    For your convenience, such a layer is supported by `kapre.composed.get_frequency_aware_conv2d()`.

    Args:
        data_format (str): specifies the data format of batch input/output.
            Defaults to the setting of your Keras configuration.
        **kwargs: Keyword args for the parent keras layer (e.g., `name`)

    Example:
        ::

            input_shape = (2048, 1)  # mono signal
            model = Sequential()
            model.add(kapre.STFT(n_fft=1024, hop_length=512, input_shape=input_shape))
            model.add(kapre.Magnitude())
            # (batch, n_frame=3, n_freq=513, ch=1) and dtype is float
            model.add(kapre.ConcatenateFrequencyMap())
            # (batch, n_frame=3, n_freq=513, ch=2)
            # now add your model
            model.add(keras.layers.Conv2D(16, (3, 3), strides=(2, 2), activation='relu'))
            # you can concatenate frequency map before other conv layers,
            # but probably, you wouldn't want to add it right before batch normalization.
            model.add(kapre.ConcatenateFrequencyMap())
            model.add(keras.layers.Conv2D(32, (3, 3), strides=(1, 1), activation='relu'))
            model.add(keras.layers.MaxPooling2D((2, 2)))  # length of frequency axis doesn't matter

    References:
        Koutini, K., Eghbal-zadeh, H., & Widmer, G. (2019).
        `Receptive-Field-Regularized CNN Variants for Acoustic Scene Classification <https://arxiv.org/abs/1909.02859>`_.
        In Proceedings of the Detection and Classification of Acoustic Scenes and Events 2019 Workshop (DCASE2019).
    """

    def __init__(self, data_format='default', **kwargs):
        backend.validate_data_format_str(data_format)
        super(ConcatenateFrequencyMap, self).__init__(**kwargs)
        # Resolve 'default' once and keep the resolved value: `_concat_frequency_map` only compares
        # against 'channels_last', so a raw 'default' would silently take the channels_first path.
        if data_format == _CH_DEFAULT_STR:
            self.data_format = K.image_data_format()
        else:
            self.data_format = data_format

    def call(self, x):
        """
        Args:
            x (`Tensor`): a 2d batch (b, t, f, ch) or (b, ch, t, f)

        Returns:
            x (`Tensor`): a 2d batch (b, t, f, ch + 1) or (b, ch + 1, t, f)
        """
        return self._concat_frequency_map(x)

    def _concat_frequency_map(self, inputs):
        """Append a (batch-, time-tiled) 0..1 frequency ramp to `inputs` along the channel axis."""
        shape = tf.shape(inputs)
        is_ch_last = self.data_format == _CH_LAST_STR
        time_axis, freq_axis, ch_axis = (1, 2, 3) if is_ch_last else (2, 3, 1)
        batch_size, n_freq, n_time = shape[0], shape[freq_axis], shape[time_axis]

        # freq_info shape: n_freq. Cast to the input dtype so tf.concat accepts it for any float input.
        freq_map_1d = tf.cast(tf.linspace(start=0.0, stop=1.0, num=n_freq), dtype=inputs.dtype)
        new_shape = (1, 1, -1, 1) if is_ch_last else (1, 1, 1, -1)
        freq_map_1d = tf.reshape(freq_map_1d, new_shape)  # 4D now
        multiples = (batch_size, n_time, 1, 1) if is_ch_last else (batch_size, 1, n_time, 1)
        freq_map_4d = tf.tile(freq_map_1d, multiples)
        return tf.concat([inputs, freq_map_4d], axis=ch_axis)

    def get_config(self):
        """Return the layer config, including every constructor argument needed to rebuild it."""
        config = super(ConcatenateFrequencyMap, self).get_config()
        config.update(
            {
                'data_format': self.data_format,
            }
        )
        return config