torchiva package#
Subpackages#
Submodules#
torchiva.auxiva_ip module#
- class torchiva.auxiva_ip.AuxIVA_IP(n_iter=10, n_src=None, model=None, proj_back_mic=0, eps=None)#
Bases: DRBSSBase
Independent vector analysis (IVA) with iterative projection (IP) update [5].
We do not support ILRMA-T with IP updates.
- Parameters
n_iter (int, optional) – The number of iterations (default: 10).
n_src (int, optional) – The number of sources to be separated. When n_src < n_chan, a computationally cheaper variant (OverIVA) [6] is used. If set to None, n_src is set to n_chan (default: None).
model (torch.nn.Module, optional) – The model of source distribution. If None, spherical Laplace is used (default: None).
proj_back_mic (int, optional) – The reference mic index to perform projection back. If set to None, projection back is not applied (default: 0).
eps (float, optional) – A small constant to make divisions and the like numerically stable (default: None).
- forward(X, n_iter=None, n_src=None, model=None, proj_back_mic=None, eps=None)#
- Parameters
X (torch.Tensor) – The input mixture in STFT-domain, shape (..., n_chan, n_freq, n_frames)
- Returns
Y – The separated signal in STFT-domain
- Return type
torch.Tensor, shape (..., n_src, n_freq, n_frames)
Note
- This class can handle three BSS methods with the IP update rule, depending on the specified arguments:
AuxIVA-IP: n_chan == n_src, model=LaplaceModel() or GaussModel()
ILRMA-IP: n_chan == n_src, model=NMFModel()
OverIVA-IP [6]: n_chan > n_src
References
- [5] N. Ono, “Stable and fast update rules for independent vector analysis based on auxiliary function technique”, WASPAA, 2011.
- [6] R. Scheibler and N. Ono, “Independent vector analysis with more microphones than sources”, WASPAA, 2019, https://arxiv.org/pdf/1905.07880.pdf.
- forward(X, n_iter=None, n_src=None, model=None, proj_back_mic=None, eps=None, verbose=False)#
Defines the computation performed at every call.
Should be overridden by all subclasses.
- Return type
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool#
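The iterative projection rule behind this class can be sketched for a single frequency bin as follows. This is a minimal NumPy illustration of the update described by Ono (2011), not torchiva's implementation; the function name `ip_update_one_source`, its argument layout, and the demo weights are assumptions made for this example.

```python
import numpy as np

def ip_update_one_source(W, X, weights, k, eps=1e-10):
    """Update row k of the demixing matrix W at one frequency bin.

    W:       (n_chan, n_chan) complex; row k holds the conjugated demixing vector w_k^H.
    X:       (n_chan, n_frames) complex mixture at this frequency.
    weights: (n_frames,) non-negative weights produced by the source model.
    """
    n_chan, n_frames = X.shape
    # Weighted covariance V_k = (1/T) * sum_t weights[t] * x_t x_t^H
    V = (X * weights) @ X.conj().T / n_frames
    # Solve (W V_k) w = e_k for the new demixing vector
    e_k = np.zeros(n_chan, dtype=complex)
    e_k[k] = 1.0
    w = np.linalg.solve(W @ V, e_k)
    # Rescale so that w^H V_k w = 1
    w = w / np.sqrt(max(np.real(w.conj() @ V @ w), eps))
    W_new = W.copy()
    W_new[k, :] = w.conj()
    return W_new, V

rng = np.random.default_rng(0)
X = rng.standard_normal((3, 200)) + 1j * rng.standard_normal((3, 200))
W = np.eye(3, dtype=complex)
# Laplace-like weights for source 1 (a hypothetical choice for this demo)
weights = 1.0 / np.maximum(np.abs((W @ X)[1]), 1e-3)
W_new, V = ip_update_one_source(W, X, weights, k=1)
```

After the update, the new demixing vector has unit weighted power and is orthogonal, under V, to the other rows of W. A full iteration would repeat this for every source and every frequency.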
- torchiva.auxiva_ip.cost(model, Y, W, J=None, g=None)#
- torchiva.auxiva_ip.orthogonal_constraint(W, Cx, load=0.0001)#
- torchiva.auxiva_ip.projection_back_from_demixing_matrix(Y, W, J=None, ref_mic=0, load=0.0001)#
- Parameters
Y (torch.Tensor (..., n_channels, n_frequencies, n_frames)) – The demixed signals
W (torch.Tensor (..., n_frequencies, n_channels, n_channels)) – The demixing matrix
ref_mic (int, optional) – The reference channel
load (float, optional) – A diagonal loading factor for the solve method
- Return type
Tensor
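As an illustration of projection back from the demixing matrix, the sketch below rescales each separated source by the matching entry of the inverse demixing matrix at the reference microphone. It is a NumPy sketch for the determined case only (no J, no diagonal loading), using the tensor shapes listed above; the name `projection_back_sketch` is hypothetical and the library's actual code may differ.

```python
import numpy as np

def projection_back_sketch(Y, W, ref_mic=0):
    # Y: (n_chan, n_freq, n_frames), W: (n_freq, n_chan, n_chan)
    A = np.linalg.inv(W)             # estimated mixing matrices, one per frequency
    scale = A[:, ref_mic, :]         # (n_freq, n_chan): gain of each source at ref_mic
    return Y * scale.T[:, :, None]   # broadcast the gains over frames

rng = np.random.default_rng(0)
n_chan, n_freq, n_frames = 3, 4, 50
A_true = rng.standard_normal((n_freq, n_chan, n_chan)) + 1j * rng.standard_normal((n_freq, n_chan, n_chan))
S = rng.standard_normal((n_chan, n_freq, n_frames)) + 1j * rng.standard_normal((n_chan, n_freq, n_frames))
X_mix = np.einsum("fck,kft->cft", A_true, S)   # mixture at the microphones
W = np.linalg.inv(A_true)                      # ideal demixing matrices
Y = np.einsum("fkc,cft->kft", W, X_mix)        # demixed signals
Z = projection_back_sketch(Y, W, ref_mic=0)
```

With an ideal demixing matrix, the rescaled sources sum back to the signal observed at the reference microphone, which is a convenient sanity check.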
torchiva.auxiva_ip2 module#
- class torchiva.auxiva_ip2.AuxIVA_IP2(n_iter=10, model=None, proj_back_mic=0, eps=None)#
Bases: DRBSSBase
Blind source separation based on independent vector analysis with alternating updates of the mixing vectors [7].
- Parameters
n_iter (int, optional) – The number of iterations (default: 10).
model (torch.nn.Module, optional) – The model of source distribution. If None, spherical Laplace is used (default: None).
proj_back_mic (int, optional) – The reference mic index to perform projection back. If set to None, projection back is not applied (default: 0).
eps (float, optional) – A small constant to make divisions and the like numerically stable (default: None).
- forward(X, n_iter=None, model=None, proj_back_mic=None, eps=None)#
- Parameters
X (torch.Tensor) – The input mixture in STFT-domain, shape (..., n_chan, n_freq, n_frames)
- Returns
X – The separated signal in STFT-domain.
- Return type
torch.Tensor, shape (..., n_chan, n_freq, n_frames)
References
- [7] N. Ono, “Fast stereo independent vector analysis and its implementation on mobile phone”, IWAENC, 2012.
- forward(X, n_iter=None, model=None, proj_back_mic=None, eps=None)#
Defines the computation performed at every call.
Should be overridden by all subclasses.
- Return type
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool#
- torchiva.auxiva_ip2.spatial_model_update_ip2(Xo, weights, W=None, A=None, eps=1e-05)#
Apply the spatial model update via the generalized eigenvalue decomposition. This method is specialized for two channels.
- Parameters
Xo (torch.Tensor, shape (..., n_frequencies, n_channels, n_frames)) – The microphone input signal with n_chan == 2
weights (torch.Tensor, shape (..., n_frequencies, n_channels, n_frames)) – The weights obtained from the source model to compute the weighted statistics
- Returns
X – The updated source estimates
- Return type
torch.Tensor, shape (n_frequencies, n_channels, n_frames)
torchiva.base module#
- class torchiva.base.BFBase(mask_model, ref_mic=0, eps=1e-05)#
Bases: Module
- property eps#
- property ref_mic#
- training: bool#
- class torchiva.base.DRBSSBase(n_iter=10, n_taps=0, n_delay=0, n_src=None, model=None, proj_back_mic=0, use_dmc=False, eps=1e-05)#
Bases: Module
- property eps#
- property n_delay#
- property n_iter#
- property n_src#
- property n_taps#
- property proj_back_mic#
- training: bool#
- property use_dmc#
- class torchiva.base.SourceModelBase#
Bases: Module
An abstract class to represent source models.
- Parameters
X (numpy.ndarray or torch.Tensor, shape (..., n_frequencies, n_frames)) – STFT representation of the signal
- Returns
P – The inverse of the source power estimate
- Return type
numpy.ndarray or torch.Tensor, shape (…, n_frequencies, n_frames)
- reset()#
The reset method is intended for models that have some internal state that should be reset for every new signal.
By default, it does nothing and should be overloaded when needed by a subclass.
- training: bool#
torchiva.beamformer module#
- class torchiva.beamformer.GEVBeamformer(mask_model, ref_mic=0, eps=1e-05)#
Bases: BFBase
Implementation of the GEV beamformer. This class assumes DNN-based beamforming.
- Parameters
mask_model (torch.nn.Module) – A function that is given one spectrogram and returns 2 masks of the same size as the input.
ref_mic (int, optional) – Reference channel (default: 0).
eps (float, optional) – A small constant to make divisions and the like numerically stable (default: 1e-5).
- forward(X, mask_model=None, ref_mic=None, eps=None)#
- Parameters
X (torch.Tensor) – The input mixture in STFT-domain, shape (..., n_chan, n_freq, n_frames)
- Returns
Y – The separated signal in STFT-domain
- Return type
torch.Tensor, shape (..., n_src, n_freq, n_frames)
- forward(X, mask_model=None, ref_mic=None, eps=None)#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool#
- class torchiva.beamformer.MVDRBeamformer(mask_model, ref_mic=0, eps=1e-05, mvdr_type='rtf', n_power_iter=None)#
Bases: BFBase
Implementation of the MVDR beamformer. This class assumes DNN-based beamforming. It also supports the case of estimating three masks.
- Parameters
mask_model (torch.nn.Module) – A function that is given one spectrogram and returns 2 or 3 masks of the same size as the input. When 3 masks (1 for the target and the other 2 for noise) are estimated, they are used as in [10].
ref_mic (int, optional) – Reference channel (default: 0).
eps (float, optional) – A small constant to make divisions and the like numerically stable (default: 1e-5).
mvdr_type (str, optional) – The way to obtain the MVDR weight. If set to 'rtf', the relative transfer function is computed to obtain the MVDR weight. If set to 'scm', the MVDR weight is obtained directly from the spatial covariance matrices [11] (default: 'rtf').
n_power_iter (int, optional) – Use the power iteration method to compute the relative transfer function instead of the full generalized eigenvalue decomposition (GEVD). The desired number of iterations should be provided. If set to None, the full GEVD is used (default: None).
- forward(X, mask_model=None, ref_mic=None, eps=None, mvdr_type=None, n_power_iter=None)#
- Parameters
X (torch.Tensor) – The input mixture in STFT-domain, shape (..., n_chan, n_freq, n_frames)
- Returns
Y – The separated signal in STFT-domain
- Return type
torch.Tensor, shape (..., n_src, n_freq, n_frames)
References
- [10] C. Boeddeker et al., “Convolutive Transfer Function Invariant SDR training criteria for Multi-Channel Reverberant Speech Separation”, ICASSP, 2021.
- [11] M. Souden, J. Benesty, and S. Affes, “On optimal frequency-domain multichannel linear filtering for noise reduction”, IEEE Trans. on Audio, Speech, and Lang. Process., 2009.
- forward(X, mask_model=None, ref_mic=None, eps=None, mvdr_type=None, n_power_iter=None)#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool#
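The two mvdr_type options can be compared with a small NumPy sketch. The formulas below are the textbook MVDR expressions (steering-vector form, and the covariance-only form of Souden et al.); the function names `mvdr_from_rtf` and `mvdr_from_scm` are hypothetical, and the code is a single-frequency illustration rather than the library's implementation.

```python
import numpy as np

def mvdr_from_rtf(covmat_noise, rtf):
    """w = Rn^{-1} d / (d^H Rn^{-1} d) for a relative transfer function d."""
    num = np.linalg.solve(covmat_noise, rtf)
    return num / (rtf.conj() @ num)

def mvdr_from_scm(covmat_target, covmat_noise, ref_mic=0):
    """w = Rn^{-1} Rs u / trace(Rn^{-1} Rs), where u selects the reference mic."""
    G = np.linalg.solve(covmat_noise, covmat_target)
    return G[:, ref_mic] / np.trace(G)

rng = np.random.default_rng(0)
n_chan = 4
a = rng.standard_normal(n_chan) + 1j * rng.standard_normal(n_chan)
a[0] = 1.0  # unit gain at the reference channel
B = rng.standard_normal((n_chan, n_chan)) + 1j * rng.standard_normal((n_chan, n_chan))
covmat_noise = B @ B.conj().T + np.eye(n_chan)   # Hermitian positive definite
covmat_target = 2.0 * np.outer(a, a.conj())      # rank-1 target covariance
w_rtf = mvdr_from_rtf(covmat_noise, a / a[0])
w_scm = mvdr_from_scm(covmat_target, covmat_noise, ref_mic=0)
```

For a rank-1 target covariance the two forms give the same weights, and the beamformer passes the target at the reference channel undistorted. With estimated, full-rank covariances they generally differ.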
- class torchiva.beamformer.MWFBeamformer(mask_model, ref_mic=0, eps=1e-05, time_invariant=True)#
Bases: BFBase
Implementation of the MWF beamformer described in [12]. This class assumes DNN-based beamforming.
- Parameters
mask_model (torch.nn.Module) – A function that is given one spectrogram and returns 2 masks of the same size as the input.
ref_mic (int, optional) – Reference channel (default: 0).
eps (float, optional) – A small constant to make divisions and the like numerically stable (default: 1e-5).
time_invariant (bool, optional) – If set to True, the time-invariant version of MWF is used. If set to False, the time-varying MWF is used instead (default: True).
- forward(X, mask_model=None, ref_mic=None, eps=None, time_invariant=None)#
- Parameters
X (torch.Tensor) – The input mixture in STFT-domain, shape (..., n_chan, n_freq, n_frames)
- Returns
Y – The separated signal in STFT-domain
- Return type
torch.Tensor, shape (..., n_src, n_freq, n_frames)
References
- [12] Y. Masuyama et al., “Consistency-aware multi-channel speech enhancement using deep neural networks”, ICASSP, 2020.
- forward(X, mask_model=None, ref_mic=None, eps=None, time_invariant=None)#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool#
- torchiva.beamformer.compute_gev_bf(covmat_target, covmat_noise, ref_mic=0)#
- torchiva.beamformer.compute_mvdr_bf(covmat_noise, steering_vector, eps=1e-05)#
- Return type
Tensor
- torchiva.beamformer.compute_mvdr_bf2(covmat_target, covmat_noise, ref_mic=0, eps=1e-05)#
- Return type
Tensor
- torchiva.beamformer.compute_mvdr_rtf_eigh(covmat_target, covmat_noise, ref_mic=0, power_iterations=None)#
Compute the Relative Transfer Function
- Parameters
covmat_target (torch.Tensor, (..., n_channels, n_channels)) – The covariance matrices of the target signal
covmat_noise (torch.Tensor, (..., n_channels, n_channels)) – The covariance matrices of the noise signal
ref_mic (int) – The channel used as the reference
power_iterations (int, optional) – An integer can be provided. If it is provided, a power iteration algorithm is used instead of the generalized eigenvalue decomposition (GEVD). If it is not provided, the regular algorithm for GEVD is used.
- Return type
Tensor
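The power-iteration alternative to the full GEVD can be sketched as below. This NumPy example handles one frequency bin and assumes the common recipe of iterating on Rn^{-1} Rs and mapping the result back through Rn; the name `rtf_power_iteration` and the choice of starting vector are assumptions, not details taken from torchiva.

```python
import numpy as np

def rtf_power_iteration(covmat_target, covmat_noise, ref_mic=0, power_iterations=10):
    # Phi = Rn^{-1} Rs; its principal eigenvector v solves Rs v = lambda Rn v
    Phi = np.linalg.solve(covmat_noise, covmat_target)
    v = Phi[:, ref_mic].copy()  # starting vector (an arbitrary choice)
    for _ in range(power_iterations):
        v = Phi @ v
        v = v / np.linalg.norm(v)
    # Map back through Rn and normalize at the reference channel
    a = covmat_noise @ v
    return a / a[ref_mic]

rng = np.random.default_rng(0)
n_chan = 4
a_true = rng.standard_normal(n_chan) + 1j * rng.standard_normal(n_chan)
a_true[0] = 1.0  # relative transfer function with unit reference gain
B = rng.standard_normal((n_chan, n_chan)) + 1j * rng.standard_normal((n_chan, n_chan))
covmat_noise = B @ B.conj().T + np.eye(n_chan)
covmat_target = 3.0 * np.outer(a_true, a_true.conj())
rtf = rtf_power_iteration(covmat_target, covmat_noise, ref_mic=0)
```

With a rank-1 target covariance the iteration recovers the true relative transfer function. For noisy covariance estimates, more iterations trade compute for accuracy.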
- torchiva.beamformer.compute_mwf_bf(covmat_target, covmat_noise, eps=1e-05, ref_mic=None)#
Compute the multichannel Wiener filter (MWF) for a given target covariance matrix and noise covariance matrix.
- Parameters
covmat_target (torch.Tensor, (..., n_channels, n_channels)) – The covariance matrices of the target signal
covmat_noise (torch.Tensor, (..., n_channels, n_channels)) – The covariance matrices of the noise signal
ref_mic (int, optional) – If a reference channel is provided, only the filter corresponding to that channel is computed. If it is not provided, all the filters are computed and the returned tensor is a square matrix.
- Returns
mwf – The multichannel Wiener filter. If ref_mic is not provided, the last two dimensions form a square matrix whose size is the number of channels. If ref_mic is provided, there is one less dimension, and the length of the last dimension is the number of channels.
- Return type
torch.Tensor, (…, n_channels, n_channels) or (…, n_channels)
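A minimal NumPy sketch of the multichannel Wiener filter, following the return-shape convention described above. It uses the standard closed form (Rs + Rn)^{-1} Rs; the diagonal loading with eps and the name `mwf_sketch` are assumptions for this example, not a statement about how torchiva regularizes the solve.

```python
import numpy as np

def mwf_sketch(covmat_target, covmat_noise, eps=1e-5, ref_mic=None):
    n_chan = covmat_target.shape[-1]
    # Diagonal loading keeps the solve well conditioned (assumed regularization)
    loaded = covmat_target + covmat_noise + eps * np.eye(n_chan)
    W = np.linalg.solve(loaded, covmat_target)   # (..., n_chan, n_chan)
    # Column ref_mic is the filter that estimates the target at that channel
    return W if ref_mic is None else W[..., :, ref_mic]

rng = np.random.default_rng(0)
B = rng.standard_normal((5, 3, 3)) + 1j * rng.standard_normal((5, 3, 3))
covmat_target = B @ np.conj(np.swapaxes(B, -1, -2)) + np.eye(3)
covmat_noise = np.zeros_like(covmat_target)      # noise-free limit
W_full = mwf_sketch(covmat_target, covmat_noise)
w_ref = mwf_sketch(covmat_target, covmat_noise, ref_mic=0)
```

In the noise-free limit the filter approaches the identity, so the input passes through unchanged; as the noise covariance grows, the filter attenuates more.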
torchiva.dtypes module#
- torchiva.dtypes.dtype_cpx2f(t)#
- Return type
dtype
- torchiva.dtypes.dtype_f2cpx(t)#
- Return type
dtype
- torchiva.dtypes.is_complex_type(t)#
- Return type
bool
torchiva.fftconvolve module#
- torchiva.fftconvolve.fftconvolve(x1, x2, mode='full', dim=-1)#
Simple function for computing the convolution of x1 with x2 via frequency domain using the FFT. We do not implement overlap add yet.
- Parameters
x1 (Tensor (…, n_samples_1)) – The first array
x2 (Tensor (…, n_samples_2)) – The second array
mode (str) – The truncation mode
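The idea of convolving through the frequency domain can be illustrated with NumPy. The sketch covers only the 'full' mode for real-valued inputs along the last axis; `fftconvolve_sketch` is a hypothetical name and this is not the library's code.

```python
import numpy as np

def fftconvolve_sketch(x1, x2):
    # Length of the full linear convolution
    n = x1.shape[-1] + x2.shape[-1] - 1
    # Zero-pad both inputs to n samples, multiply the spectra, transform back
    X1 = np.fft.rfft(x1, n=n, axis=-1)
    X2 = np.fft.rfft(x2, n=n, axis=-1)
    return np.fft.irfft(X1 * X2, n=n, axis=-1)

x1 = np.array([1.0, 2.0, 3.0])
x2 = np.array([0.0, 1.0, 0.5])
y = fftconvolve_sketch(x1, x2)
```

Zero-padding to the full output length is what turns the circular convolution implied by the FFT into a linear one.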
torchiva.five module#
- class torchiva.five.FIVE(n_iter=10, model=None, proj_back_mic=0, eps=None, n_power_iter=None)#
Bases: DRBSSBase
Fast independent vector extraction (FIVE) [8]. FIVE extracts one source from the input signal.
- Parameters
n_iter (int, optional) – The number of iterations (default: 10).
model (torch.nn.Module, optional) – The model of source distribution (default: LaplaceModel).
proj_back_mic (int, optional) – The reference mic index to perform projection back. If set to None, projection back is not applied (default: 0).
eps (float, optional) – A small constant to make divisions and the like numerically stable (default: None).
n_power_iter (int, optional) – The number of power iterations. If set to None, eigenvector decomposition is used instead (default: None).
- forward(X, n_iter=None, model=None, proj_back_mic=None, eps=None)#
- Parameters
X (torch.Tensor) – The input mixture in STFT-domain, shape (..., n_chan, n_freq, n_frames)
- Returns
Y – The single extracted signal in STFT-domain.
- Return type
torch.Tensor, shape (..., n_freq, n_frames)
References
- [8] R. Scheibler and N. Ono, “Fast independent vector extraction by iterative SINR maximization”, ICASSP, 2020, https://arxiv.org/pdf/1910.10654.pdf.
- forward(X, n_iter=None, model=None, proj_back_mic=None, eps=None, n_power_iter=None)#
Defines the computation performed at every call.
Should be overridden by all subclasses.
- Return type
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool#
- torchiva.five.adjust_global_scale(Y, ref)#
- torchiva.five.normalize(x)#
- torchiva.five.smallest_eigenvector_eigh_cpu(V)#
- torchiva.five.smallest_eigenvector_power_method(V, n_iter=10)#
torchiva.linalg module#
- torchiva.linalg.bmm(input, mat2)#
- Return type
Tensor
- torchiva.linalg.diagonal_loading(A, d)#
Load the diagonal of A with the vector d
- torchiva.linalg.divide(num, denom, eps=1e-07)#
- torchiva.linalg.eigh(A, B=None, eps=1e-15, use_eigh_cpu=False)#
Eigenvalue decomposition of a complex Hermitian symmetric matrix
- Return type
Tuple[Tensor,Tensor]
- torchiva.linalg.eigh_2x2(A, B=None, eps=0.0)#
Specialized routine for batched 2x2 EVD and GEVD for complex hermitian matrices
- Return type
Tuple[Tensor,Tensor]
- torchiva.linalg.eigh_wrapper(V, use_cpu=True)#
- torchiva.linalg.hankel_view(x, n_rows)#
Return a view of x as a Hankel matrix.
- Return type
Tensor
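To illustrate what a Hankel view is, the NumPy sketch below builds a matrix H with H[i, j] = x[i + j] without copying data. The row and column layout is an assumption made for this example (the torchiva function may order its dimensions differently), and `hankel_view_sketch` is a hypothetical name.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def hankel_view_sketch(x, n_rows):
    # Each row is the previous row shifted by one sample along the last axis
    n_cols = x.shape[-1] - n_rows + 1
    return sliding_window_view(x, n_cols, axis=-1)   # (..., n_rows, n_cols)

x = np.arange(5)
H = hankel_view_sketch(x, n_rows=2)
```

Because the result is a view, building the delayed copies of a signal this way costs no extra memory, which matters when stacking many taps.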
- torchiva.linalg.hermite(A, dim1=-2, dim2=-1)#
- torchiva.linalg.inv_2x2(W, eps=1e-06)#
- torchiva.linalg.inv_loaded(A, load=1e-06)#
- torchiva.linalg.mag(x)#
- torchiva.linalg.mag_sq(x)#
- torchiva.linalg.multiply(tensor1, tensor2)#
- torchiva.linalg.solve_loaded(A, b, load=1e-06)#
- torchiva.linalg.solve_loaded_general(A, b, load=1e-05, eps=1e-05)#
torchiva.models module#
torchiva.nn module#
- class torchiva.nn.BSSSeparator(n_fft, n_iter=20, hop_length=None, window=None, n_taps=0, n_delay=0, n_src=None, algo=SepAlgo.T_ISS, source_model=None, proj_back_mic=0, use_dmc=False, n_power_iter=None, eps=1e-05)#
Bases: Module
- property algo#
- forward(x, reset=True)#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool#
torchiva.parameters module#
torchiva.scaling module#
- class torchiva.scaling.Scaling#
Bases: Module
We should implement the scaling step as a PyTorch model.
- forward(X)#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool#
- torchiva.scaling.lp_norm(E, p=1)#
- Return type
Tensor
- torchiva.scaling.lpq_norm(E, p=1, q=2, axis=0)#
- Return type
Tensor
- torchiva.scaling.minimum_distortion(Y, ref, p=None, q=None, rtol=0.01, max_iter=100, model=None)#
This function computes the frequency-domain filter that minimizes the sum of errors to a reference signal with a mixed-norm. This is a sparse version of the projection back that is commonly used to solve the scale ambiguity in BSS.
- Parameters
Y (array_like (n_frequencies, n_channels, n_frames)) – The STFT data to project back on the reference signal
ref (array_like (n_frames, n_freq)) – The reference signal
p (float (0 < p <= 2)) – The norm to use to measure distortion
q (float (0 < p <= q <= 2)) – The other exponent when using a mixed norm to measure distortion
max_iter (int, optional) – Maximum number of iterations
rtol (float, optional) – Stop the optimization when the algorithm makes less than rtol relative progress
model (torch.nn.Module) – An optional learnable block to replace the MM weights
- Return type
Tensor
- torchiva.scaling.minimum_distortion_l2(Y, ref)#
This function computes the frequency-domain filter that minimizes the squared error to a reference signal. This is commonly used to solve the scale ambiguity in BSS.
- Parameters
Y (torch.Tensor (..., n_channels, n_frequencies, n_frames)) – The STFT data to project back on the reference signal
ref (torch.Tensor (..., n_frequencies, n_frames)) – The reference signal
- Return type
Tensor
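The squared-error rescaling has a closed form: for each source and frequency, the optimal complex scale is the projection of the reference onto the separated signal. The NumPy sketch below uses the shapes listed above without batch dimensions; the name `minimum_distortion_l2_sketch` and the eps guard are assumptions for this example.

```python
import numpy as np

def minimum_distortion_l2_sketch(Y, ref, eps=1e-10):
    # Y: (n_chan, n_freq, n_frames), ref: (n_freq, n_frames)
    # alpha minimizes sum_t |ref - alpha * Y|^2 for each (source, frequency) pair
    num = np.sum(np.conj(Y) * ref[None], axis=-1, keepdims=True)
    denom = np.sum(np.abs(Y) ** 2, axis=-1, keepdims=True)
    return Y * (num / np.maximum(denom, eps))

rng = np.random.default_rng(0)
ref = rng.standard_normal((4, 30)) + 1j * rng.standard_normal((4, 30))
gains = np.array([0.5 + 0.5j, -2.0, 3.0j])   # arbitrary per-source scales
Y = gains[:, None, None] * ref[None]         # scaled copies of the reference
Y_scaled = minimum_distortion_l2_sketch(Y, ref)
```

In this toy case every source is a scaled copy of the reference, so the rescaling recovers the reference exactly. With real separated signals, each source keeps only its component at the reference.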
- torchiva.scaling.minimum_distortion_l2_phase(Y, ref)#
This function computes the frequency-domain filter that minimizes the squared error to a reference signal. This is commonly used to solve the scale ambiguity in BSS.
- Parameters
Y (torch.Tensor (..., n_channels, n_frequencies, n_frames)) – The STFT data to project back on the reference signal
ref (torch.Tensor (..., n_frequencies, n_frames)) – The reference signal
- Return type
Tensor
- torchiva.scaling.projection_back(Y, ref)#
Solves the scale ambiguity according to Murata et al., 2001. This technique uses the steering vector value corresponding to the demixing matrix obtained during separation.
- Parameters
Y (torch.Tensor (n_batch, n_channels, n_frequencies, n_frames)) – The STFT data to project back on the reference signal
ref (torch.Tensor (..., n_frequencies, n_frames)) – The reference signal
- Return type
NoReturn
torchiva.stft module#
- class torchiva.stft.STFT(n_fft, hop_length=None, window=None, dtype=None)#
Bases: Module
- forward(x)#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- property hop_length: int#
- inv(x)#
- property n_fft: int#
- property n_freq: int#
- training: bool#
- property window: Tensor#
- property window_type: str#
torchiva.t_iss module#
- class torchiva.t_iss.T_ISS(n_iter=10, n_taps=0, n_delay=0, n_src=None, model=None, proj_back_mic=0, use_dmc=False, eps=None)#
Bases: DRBSSBase
Joint dereverberation and separation with time-decorrelation iterative source steering (T-ISS) [1].
Parameters can also be specified during a forward call. In this case, the forward argument is only used in that forward process and does not rewrite class attributes.
- Parameters
n_iter (int, optional) – The number of iterations (default: 10).
n_taps (int, optional) – The length of the dereverberation filter. If set to 0, this method works as the normal AuxIVA with ISS update [2] (default: 0).
n_delay (int, optional) – The delay for dereverberation (default: 0).
n_src (int, optional) – The number of sources to be separated. When n_src < n_chan, a computationally cheaper variant (Over-T-ISS) [3] is used. If set to None, n_src is set to n_chan (default: None).
model (torch.nn.Module, optional) – The model of source distribution. A mask estimation neural network can also be used. If None, spherical Laplace is used (default: None).
proj_back_mic (int, optional) – The reference mic index to perform projection back. If set to None, projection back is not applied (default: 0).
use_dmc (bool, optional) – If set to True, memory-efficient Demixing Matrix Checkpointing (DMC) [4] is used to compute the gradient. It reduces the memory cost to that of a single iteration when training a neural source model (default: False).
eps (float, optional) – A small constant to make divisions and the like numerically stable (default: None).
- forward(X, n_iter=None, n_taps=None, n_delay=None, n_src=None, model=None, proj_back_mic=None, use_dmc=None, eps=None)#
- Parameters
X (torch.Tensor) – The input mixture in STFT-domain, shape (..., n_chan, n_freq, n_frames)
- Returns
Y – The separated and dereverberated signal in STFT-domain
- Return type
torch.Tensor, shape (..., n_src, n_freq, n_frames)
Note
- This class can handle various BSS methods with the ISS update rule, depending on the specified arguments:
IVA-ISS: n_taps=0, n_delay=0, n_chan == n_src, model=LaplaceModel() or GaussModel()
ILRMA-ISS: n_taps=0, n_delay=0, n_chan == n_src, model=NMFModel()
DNN-IVA-ISS: n_taps=0, n_delay=0, n_chan == n_src, model=*DNN*
OverIVA-ISS: n_taps=0, n_delay=0, n_chan > n_src
ILRMA-T-ISS [1]: n_taps>0, n_delay>0, n_chan == n_src, model=NMFModel()
DNN-T-ISS [4]: n_taps>0, n_delay>0, n_chan == n_src, model=*DNN*
Over-T-ISS [3]: n_taps>0, n_delay>0, n_chan > n_src
References
- [1] T. Nakashima, R. Scheibler, M. Togami, and N. Ono, “Joint dereverberation and separation with iterative source steering”, ICASSP, 2021, https://arxiv.org/pdf/2102.06322.pdf.
- [2] R. Scheibler and N. Ono, “Fast and stable blind source separation with rank-1 updates”, ICASSP, 2021.
- [3] R. Scheibler, W. Zhang, X. Chang, S. Watanabe, and Y. Qian, “End-to-End Multi-speaker ASR with Independent Vector Analysis”, arXiv preprint arXiv:2204.00218, 2022, https://arxiv.org/pdf/2204.00218.pdf.
- [4] K. Saijo and R. Scheibler, “Independence-based Joint Speech Dereverberation and Separation with Neural Source Model”, arXiv preprint arXiv:2110.06545, 2022, https://arxiv.org/pdf/2110.06545.pdf.
- forward(X, n_iter=None, n_taps=None, n_delay=None, n_src=None, model=None, proj_back_mic=None, use_dmc=None, eps=None)#
Defines the computation performed at every call.
Should be overridden by all subclasses.
- Return type
Tensor
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool#
- torchiva.t_iss.background_update(W, H, C_XX, C_XbarX, eps=1e-05)#
Recomputes J based on W, H, and C_XX = E[X X^H] and C_XbarX = E[ X_bar X^H ]
- torchiva.t_iss.demix_background(X, J)#
- torchiva.t_iss.demix_derev(X, X_bar, W, H)#
- torchiva.t_iss.iss_block_update_type_1(src, X, weights, eps=0.001)#
Compute the update vector for ISS corresponding to the update of the sources, Equation (9) in [1].
- Return type
Tensor
- torchiva.t_iss.iss_block_update_type_2(X, Zs, weights, eps=0.001)#
Compute the update vector for ISS corresponding to the update of the taps, Equation (9) in [1].
- Return type
Tensor
- torchiva.t_iss.iss_block_update_type_3(src, tap, X, X_bar, weights, eps=0.001)#
Compute the update vector for ISS corresponding to the update of the taps, Equation (9) in [1].
- Return type
Tensor
- torchiva.t_iss.iss_updates_with_H(X, X_bar, W, H, weights, J=None, Z=None, eps=0.001)#
ISS updates performed in-place
- Parameters
X (torch.Tensor (..., n_src, n_freq, n_frames)) – Separated signals
Z (torch.Tensor (..., n_chan - n_src, n_freq, n_frames)) –
X_bar (torch.Tensor (..., )) – Delayed versions of the input signal
W (torch.Tensor (..., n_src, n_freq, n_chan)) – The demixing matrix part corresponding to target sources
H (torch.Tensor (..., n_src, n_freq, n_chan, n_taps)) – The dereverberation matrix
J (torch.Tensor (..., n_chan - n_src, n_freq, n_src)) – The demixing matrix part corresponding to background
weights (torch.Tensor (..., n_src, n_freq, n_frames)) – The separation masks
n_src (int, optional) – The number of target sources
eps (float, optional) – A small constant used for numerical stability
- Return type
Tuple[Tensor,Tensor,Tensor,Tensor]
- torchiva.t_iss.over_iss_t_one_iter(Y, X, X_bar, C_XX, C_XbarX, W, H, J, model, eps=0.001)#
- torchiva.t_iss.over_iss_t_one_iter_dmc(X, X_bar, C_XX, C_XbarX, W, H, J, model, eps=0.001, *model_params)#
- torchiva.t_iss.projection_back_weights(W, J=None, ref_mic=0, eps=1e-06)#
torchiva.utils module#
- torchiva.utils.import_name(name)#
- torchiva.utils.instantiate(name, args=None, kwargs=None)#
Get a model by its name.
- Parameters
name (str) – Name of the model class
kwargs (Optional[Dict]) – A dict containing all the arguments to the model
- torchiva.utils.select_most_energetic(x, num, dim=-2, dim_reduc=-1)#
Selects the num indices with most power
- Parameters
x (torch.Tensor (n_batch, n_channels, n_samples)) – The input tensor
num (int) – The number of signals to select
dim (Optional[int]) – The axis where the selection should occur
dim_reduc (Optional[int]) – The axis where to perform the reduction
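The selection can be pictured with a short NumPy sketch that fixes dim=-2 and dim_reduc=-1, the documented defaults. Returning the channels in decreasing order of power is an assumption made for this example, as is the name `select_most_energetic_sketch`.

```python
import numpy as np

def select_most_energetic_sketch(x, num):
    # x: (n_batch, n_channels, n_samples)
    # Average power of each channel, reduced over the sample axis
    power = np.mean(np.abs(x) ** 2, axis=-1, keepdims=True)   # (B, C, 1)
    # Indices of the num most powerful channels, strongest first
    order = np.argsort(-power, axis=-2)[..., :num, :]         # (B, num, 1)
    return np.take_along_axis(x, order, axis=-2)

rng = np.random.default_rng(0)
scales = np.array([1.0, 10.0, 0.1, 5.0])                      # per-channel amplitudes
x = rng.standard_normal((2, 4, 100)) * scales[None, :, None]
out = select_most_energetic_sketch(x, num=2)
```

Here channels 1 and 3 carry the most power, so they are the two that are kept for every batch element.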
torchiva.wpe module#
- class torchiva.wpe.WPE(n_iter=3, n_delay=3, n_taps=5, model=None, eps=1e-05)#
Bases: DRBSSBase
Weighted prediction error (WPE) [9].
- Parameters
n_iter (int, optional) – The number of iterations (default: 3).
n_taps (int, optional) – The length of the dereverberation filter (default: 5).
n_delay (int, optional) – The delay for dereverberation (default: 3).
model (torch.nn.Module, optional) – The model of source distribution. If None, time-varying Gaussian is used (default: None).
eps (float, optional) – A small constant to make divisions and the like numerically stable (default: 1e-5).
- Returns
Y – The dereverberated signal in STFT-domain.
- Return type
torch.Tensor, shape (..., n_src, n_freq, n_frames)
References
- [9] T. Nakatani, T. Yoshioka, K. Kinoshita, M. Miyoshi, and B. H. Juang, “Speech dereverberation based on variance-normalized delayed linear prediction”, IEEE Trans. on Audio, Speech, and Lang. Process., 2010.
- forward(X, n_iter=None, n_delay=None, n_taps=None, model=None, eps=None)#
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool#
- torchiva.wpe.derev(H, X, X_bar)#
- torchiva.wpe.wpe_default_weights(Y, eps=1e-05)#
- Return type
Tensor
- torchiva.wpe.wpe_one_iter(Y, X, X_bar, model=None, eps=1e-05)#
- Parameters
Y (torch.Tensor, (..., n_chan, n_freq, n_frames)) – The current estimate of the dereverberated signal
X (torch.Tensor, (..., n_chan, n_freq, n_frames)) – Input signal
X_bar (torch.Tensor, (..., n_chan, n_freq, n_taps, n_frames)) – Delayed version of input signal
- Returns
H – The updated dereverberation filter weights
- Return type
torch.Tensor, (…, n_freq, n_chan, n_taps, n_chan)
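One WPE iteration can be sketched for a single frequency bin as a weighted least-squares problem. The NumPy code below is an illustration of the general method of Nakatani et al., not torchiva's implementation: the helper names `delayed_taps` and `wpe_one_iter_sketch`, the stacked (n_taps * n_chan) layout of X_bar and H, and the diagonal loading are all assumptions made for this example.

```python
import numpy as np

def delayed_taps(X, n_taps, n_delay):
    """Stack delayed copies of X (n_chan, n_frames) into (n_taps * n_chan, n_frames)."""
    n_chan, n_frames = X.shape
    X_bar = np.zeros((n_taps, n_chan, n_frames), dtype=X.dtype)
    for tap in range(n_taps):
        shift = n_delay + tap
        if shift < n_frames:
            X_bar[tap, :, shift:] = X[:, : n_frames - shift]
    return X_bar.reshape(n_taps * n_chan, n_frames)

def wpe_one_iter_sketch(Y, X, X_bar, eps=1e-5):
    # Time-varying Gaussian weights: inverse of the channel-averaged power of Y
    weights = 1.0 / np.maximum(np.mean(np.abs(Y) ** 2, axis=0), eps)
    # Weighted correlation statistics of the delayed and current frames
    R = (X_bar * weights) @ X_bar.conj().T     # (K, K) with K = n_taps * n_chan
    P = (X_bar * weights) @ X.conj().T         # (K, n_chan)
    # Regularized normal equations R H = P
    H = np.linalg.solve(R + eps * np.eye(R.shape[0]), P)
    # Subtract the predicted late reverberation
    Y_new = X - H.conj().T @ X_bar
    return Y_new, H

rng = np.random.default_rng(0)
X = rng.standard_normal((2, 300)) + 1j * rng.standard_normal((2, 300))
X_bar = delayed_taps(X, n_taps=3, n_delay=2)
Y_new, H = wpe_one_iter_sketch(X, X, X_bar)    # first iteration starts from Y = X
```

Each iteration cannot increase the weighted prediction error for the weights it was solved with; repeating the step with weights recomputed from Y_new gives the usual alternating scheme.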