Common implementations of VAE models, such as image VAEs or sequential video VAEs, use the MSE loss for reconstruction and tune a heuristic weight on the KL divergence term of the objective. This MSE loss corresponds to the log-likelihood of a Gaussian decoder distribution with a fixed, constant variance. However, as we show in the paper, the assumption of constant variance is actually very problematic! Instead, it is better to learn the variance of the decoder, so that the decoding distribution is calibrated. We propose a simple method for learning this variance that empirically improves sample quality for multiple VAE and sequence VAE models. Moreover, we find that the weight on the KL divergence term can simply be set to one, as our method automatically balances the two terms of the objective. Our method improves performance over naive decoder choices, reduces the need for hyperparameter tuning, and can be implemented in 5 lines of code.
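The claim that MSE corresponds to a constant-variance Gaussian decoder can be checked numerically. Below is a small plain-Python sketch. The helper names `gaussian_nll` and `mse` are hypothetical and not from the released code. For a fixed σ, the Gaussian negative log-likelihood is the summed squared error scaled by 1/(2σ²), plus terms that do not depend on the decoder output. Minimizing one is therefore the same as minimizing the other, up to scale.

```python
import math


def gaussian_nll(x, mu, sigma):
    """-log N(x; mu, sigma^2 I), summed over all dimensions of x."""
    d = len(x)
    sse = sum((xi - mi) ** 2 for xi, mi in zip(x, mu))
    # Only the first term depends on mu; the other two depend only on sigma.
    return sse / (2.0 * sigma ** 2) + d * math.log(sigma) + 0.5 * d * math.log(2.0 * math.pi)


def mse(x, mu):
    """Mean squared error between a target x and a reconstruction mu."""
    return sum((xi - mi) ** 2 for xi, mi in zip(x, mu)) / len(x)


# Two candidate reconstructions of the same (toy) target.
x = [0.2, 0.5, 0.9]
mu_a = [0.1, 0.4, 1.0]
mu_b = [0.0, 0.7, 0.6]
sigma = 1.0

# With sigma fixed, NLL differences are exactly d / (2 sigma^2) times MSE differences,
# because the log(sigma) and log(2*pi) terms cancel.
diff_nll = gaussian_nll(x, mu_a, sigma) - gaussian_nll(x, mu_b, sigma)
diff_mse = mse(x, mu_a) - mse(x, mu_b)
print(math.isclose(diff_nll, len(x) / (2 * sigma ** 2) * diff_mse))
```

The fixed σ still matters. It sets how strongly the reconstruction term is weighted relative to the KL term, which is the quantity the rest of this page is about.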
When a Gaussian Mixture Model is trained with a constant variance (left), very different results are obtained for different values of that variance. With a large variance, the Gaussians cannot specialize and collapse into a single Gaussian, while with a small variance they become overly sharp. Similarly, when the variance of a VAE's decoding distribution is set to a constant (right), the resulting samples are blurry if the variance is too high (due to posterior collapse) and unrealistic if it is too low. This can be remedied by tuning a heuristic weight on the KL divergence term. However, we observe that learning the variance end-to-end with the ELBO finds a variance that produces good samples.
Figure (left panel): Gaussian Mixture Model
There are several design choices for learning the variance. The naive way would be to output a per-pixel variance as an additional channel of the decoder (left). However, we empirically find that this often gives suboptimal results. Instead, learning a single shared variance that is independent of the latent (center) often works better. We can further improve the optimization of this shared variance through an analytic solution (right). For a given reconstruction, the maximum-likelihood value of the shared variance is the mean squared error, so it can be computed in closed form instead of being learned by gradient descent. We call the resulting method σ-VAE.
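The analytic variant (right) can be sketched in a few lines. This version is written in plain Python on flat lists rather than in PyTorch or TensorFlow, so it runs without dependencies. The function names are hypothetical. The lower clamp on log σ, and its value of -6, are our own illustrative choices to avoid log(0) on a perfect reconstruction. This sketch computes σ for a single example; a real model would typically compute the error over a batch.

```python
import math


def gaussian_nll(x, mu, log_sigma):
    """-log N(x; mu, sigma^2 I) with one sigma shared by every dimension."""
    d = len(x)
    sse = sum((xi - mi) ** 2 for xi, mi in zip(x, mu))
    return sse / (2.0 * math.exp(2.0 * log_sigma)) + d * log_sigma + 0.5 * d * math.log(2.0 * math.pi)


def optimal_log_sigma(x, mu, min_log_sigma=-6.0):
    """Closed-form maximum-likelihood log sigma for a shared variance: sigma^2 = MSE.

    The lower clamp (an assumption, value chosen for illustration) keeps log sigma
    finite when the reconstruction is exact.
    """
    mse = sum((xi - mi) ** 2 for xi, mi in zip(x, mu)) / len(x)
    if mse <= 0.0:
        return min_log_sigma
    return max(0.5 * math.log(mse), min_log_sigma)


def sigma_vae_recon_loss(x, x_hat):
    """Reconstruction term using the analytic shared variance instead of a learned one."""
    return gaussian_nll(x, x_hat, optimal_log_sigma(x, x_hat))


x = [0.2, 0.5, 0.9, 0.3]
x_hat = [0.1, 0.4, 1.0, 0.5]
loss = sigma_vae_recon_loss(x, x_hat)
```

At the optimum, the squared-error term equals D/2, so the loss reduces to D·log σ plus a constant. Because σ sits at its optimum, the derivative of the loss with respect to σ is zero there. In an autodiff framework, backpropagating through σ or treating it as a constant therefore gives the same gradient for the decoder output, as long as the clamp is inactive.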
A typical Gaussian VAE with unit variance (left) fixes an arbitrary balance between the reconstruction and KL divergence terms and produces suboptimal, blurry samples. σ-VAE balances the objective automatically by learning the variance, which acts as a weighting factor between the two terms of the loss.
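The reading of the variance as a balance factor follows from a short derivation. It assumes a Gaussian decoder whose single variance σ² is shared across all D output dimensions, with the squared error summed (not averaged) over dimensions. Starting from the negative ELBO and multiplying by the positive constant 2σ² leaves the minimizer over the encoder and decoder unchanged when σ is fixed:

```latex
\begin{aligned}
-\log p_\theta(x \mid z) &= \frac{\lVert x - \mu_\theta(z) \rVert^2}{2\sigma^2} + D \log \sigma + \frac{D}{2}\log 2\pi, \\
\mathcal{L}_\sigma(x) &= \mathbb{E}_{q_\phi(z \mid x)}\!\left[\frac{\lVert x - \mu_\theta(z) \rVert^2}{2\sigma^2}\right] + \mathrm{KL}\big(q_\phi(z \mid x) \,\Vert\, p(z)\big) + D \log \sigma + \text{const}, \\
2\sigma^2 \, \mathcal{L}_\sigma(x) &= \mathbb{E}_{q_\phi(z \mid x)}\!\left[\lVert x - \mu_\theta(z) \rVert^2\right] + 2\sigma^2 \, \mathrm{KL}\big(q_\phi(z \mid x) \,\Vert\, p(z)\big) + \text{const}'.
\end{aligned}
```

A fixed-variance decoder is therefore equivalent to a KL weight of β = 2σ² on a summed squared-error loss. Learning σ, or setting it analytically, chooses this weight from the data rather than leaving it as a hand-tuned hyperparameter.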
We have released minimal implementations of the method in PyTorch and TensorFlow. It takes only 5 lines of code to try σ-VAE with your VAE model! We have also released an implementation of σ-VAE with the Stochastic Video Generation (SVG) method.