How (not) to Train your Generative Model: Scheduled Sampling, Likelihood, Adversary?
Abstract
Modern applications and progress in deep learning research have created renewed interest in generative models of text and of images. However, even today it is unclear what objective functions one should use to train and evaluate these models. In this paper we present two contributions.
Firstly, we present a critique of scheduled sampling, a state-of-the-art training method that contributed to the winning entry to the MSCOCO image captioning benchmark in 2015. Here we show that despite this impressive empirical performance, the objective function underlying scheduled sampling is improper and leads to an inconsistent learning algorithm.
Secondly, we revisit the problems that scheduled sampling was meant to address, and present an alternative interpretation. We argue that maximum likelihood is an inappropriate training objective when the end goal is to generate natural-looking samples. We go on to derive an ideal objective function to use in this situation instead. We introduce a generalisation of adversarial training, and show how such a method can interpolate between maximum likelihood training and our ideal training objective. To our knowledge this is the first theoretical analysis that explains why adversarial training tends to produce samples with higher perceived quality.
1 Introduction
Building sophisticated generative models that produce realistic-looking images or text is an important current frontier of unsupervised learning. The renewed interest in generative models can be attributed to two factors. Firstly, thanks to the active investment in machine learning by internet companies, we now have several products and practical use cases for generative models: texture generation (Han et al., 2008), speech synthesis (Ou & Zhang, 2012), image caption generation (Lin et al., 2014; Vinyals et al., 2014), machine translation (Sutskever et al., 2014), conversation and dialogue generation (Vinyals & Le, 2015; Sordoni et al., 2015). Secondly, recent successes in generative models, particularly those based on deep representation learning, have raised hopes that our systems may one day reach the sophistication required in these practical use cases.
While noticeable progress has been made in generative modelling, in many applications we are still far from generating fully realistic samples. One of the key open questions is what objective functions one should use to train and evaluate generative models (Theis et al., 2015). The model likelihood is often considered the most principled training objective, and most research in the past decades has focussed on maximum likelihood (ML) and approximations thereof (Hinton et al., 2006; Hyvärinen, 2006; Kingma & Welling, 2013). Recently we have seen promising new training strategies, such as those based on adversarial networks (Goodfellow et al., 2014; Denton et al., 2015) and kernel moment matching (Li et al., 2015; Dziugaite et al., 2015), which are not, at least on the surface, related to maximum likelihood. Most of this departure from ML was motivated by the fact that the exact likelihood is intractable in most models. However, some authors have recently observed that even in models whose likelihood is tractable, ML training leads to undesired behaviour, and introduced new training procedures that deliberately differ from maximum likelihood. Here we will focus on scheduled sampling (Bengio et al., 2015), which is an example of this.
In this paper we attempt to clarify which objective functions might work well for the generative scenario and which ones one should avoid. In line with Theis et al. (2015) and Lacoste-Julien et al. (2011), we believe that the objective function used for training should reflect the task we ultimately want to use the model for. In the context of this paper, we focus on generative models that are created with the sole purpose of generating realistic-looking samples. This narrower definition extends to use cases such as image captioning, texture generation, machine translation and dialogue systems, but excludes tasks such as unsupervised pre-training for supervised learning, semi-supervised learning, data compression, denoising and many others.
This paper is organised around the following main contributions:
 scheduled sampling is improper:

In the first half of this paper we focus on autoregressive models for sequence generation. These models are interesting for us mainly because exact maximum likelihood training is tractable, even in relatively complex models such as stacked LSTMs (Bengio et al., 2015; Sutskever et al., 2014; Theis & Bethge, 2015). However, it has been observed that autoregressive generative models trained via ML exhibit some undesired behaviour when they are used to generate samples. We revisit a recent attempt to remedy these problems: scheduled sampling. We re-express the scheduled sampling training objective in terms of Kullback-Leibler divergences, and show that it is in fact an improper training objective. Therefore we recommend using scheduled sampling with care.
 KLdivergence as a model of perceptual loss:

In the latter part of the paper we seek an alternative solution to the problem scheduled sampling was meant to address. We uncover a more fundamental problem that applies to all generative models: the likelihood is not the right training objective when the goal is to generate realistic samples. Maximum likelihood can be thought of as minimising the Kullback-Leibler divergence $\mathrm{KL}[P\|Q]$ between the real data distribution $P$ and the probabilistic model $Q$. We present a model that suggests generative models should instead be trained to minimise $\mathrm{KL}[Q\|P]$, the Kullback-Leibler divergence in the opposite direction. The differences between minimising $\mathrm{KL}[P\|Q]$ and $\mathrm{KL}[Q\|P]$ are well understood, and explain the observed undesirable behaviour in autoregressive sequence models.
 generalised adversarial training:

Unfortunately, $\mathrm{KL}[Q\|P]$ is even harder to optimise than the likelihood, so it is unlikely to yield a viable training procedure. Instead, we suggest minimising an information quantity which we call the generalised Jensen-Shannon divergence. We show that this divergence can effectively interpolate between the behaviour of $\mathrm{KL}[P\|Q]$ and $\mathrm{KL}[Q\|P]$, thereby containing both maximum likelihood and our ideal perceptual objective function as special cases. We also show that generalisations of the adversarial training procedure proposed by Goodfellow et al. (2014) can be employed to approximately minimise this divergence function. Our analysis also provides a new theoretical explanation for the success of adversarial training in producing qualitatively superior samples.
2 Autoregressive Models for Sequence Generation
In this section we will focus on a particularly useful class of probabilistic models, which we call autoregressive generative models (see e.g. Theis et al., 2012; Larochelle & Murray, 2011; Bengio et al., 2015). An autoregressive probabilistic model explicitly defines the joint distribution over a sequence of symbols $x_1, \dots, x_N$ recursively as follows:
(1) $Q(x_1, x_2, \dots, x_N) = Q(x_1) \prod_{n=2}^{N} Q(x_n \mid x_1, \dots, x_{n-1})$
We note that technically the above equation holds for all joint distributions $Q$; here we further assume that each of the component distributions $Q(x_n \mid x_1, \dots, x_{n-1})$ is tractable and easy to compute. Autoregressive models are considered relatively easy to train, as the model likelihood is typically tractable. This allows us to train even complicated deep models such as stacked LSTMs in the coherent and well understood framework of maximum likelihood estimation (Theis et al., 2012, 2015).
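As a minimal sketch of why the likelihood is tractable in this model class, the snippet below evaluates the log-likelihood of a sequence by summing log conditionals, following the chain rule. The two-symbol alphabet, the probability tables `FIRST` and `NEXT`, and the helper names are all hypothetical; for brevity the toy conditional looks only at the last symbol, whereas a general autoregressive model would condition on the whole prefix.

```python
import math

# Hypothetical toy model over the alphabet {"a", "b"} (illustrative numbers only).
FIRST = {"a": 0.6, "b": 0.4}                 # Q(x_1)
NEXT = {"a": {"a": 0.7, "b": 0.3},           # Q(x_n | previous symbol)
        "b": {"a": 0.2, "b": 0.8}}


def conditional(prefix):
    """Return the model's distribution over the next symbol given the prefix."""
    return FIRST if not prefix else NEXT[prefix[-1]]


def log_likelihood(sequence):
    """log Q(x_1, ..., x_N), computed as a sum of log conditionals."""
    total = 0.0
    for n, symbol in enumerate(sequence):
        total += math.log(conditional(sequence[:n])[symbol])
    return total
```

Each term costs a single lookup here; in a recurrent network it would cost one forward step, which is what keeps exact maximum likelihood training feasible.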
3 The symptoms
Despite the elegance of closed-form maximum likelihood training, Bengio et al. (2015) have pointed out that maximum likelihood training leads to undesirable behaviour when the models are used to generate samples. In this section we review these symptoms, and throughout this paper we will explore different strategies aimed at explaining and remedying them.
Typically, when training an AR model, one maximises the log predictive likelihood of the $n$th symbol in each training sentence conditioned on all previous symbols in the sequence, which we collectively call the prefix. This can be thought of as a special case of maximum likelihood learning, as the joint likelihood over all symbols in a sequence factorises into these conditionals via the chain rule of probabilities.
When using the trained model to generate sample sequences, we generate each new sequence symbol by symbol in a recursive fashion: assuming we have already generated a prefix of $n$ symbols, we feed that prefix into the conditional model, and ask it to output the predictive distribution for the $(n+1)$st character. The $(n+1)$st character is then sampled from this distribution and added to the prefix.
Crucially, at training time the RNN only sees prefixes from real training sequences. However, at generation time, it can generate a prefix that is never seen in the training data. Once an unlikely prefix is generated, the model typically has a hard time recovering from the mistake, and will start outputting a seemingly random string of symbols, ending up with a sample that has poor perceptual quality and is very unlikely under the true sequence distribution $P$.
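The recursive generation procedure described above can be sketched as follows. This is only an illustration under assumed names: the toy tables `FIRST` and `NEXT`, the helper `conditional` and the function `sample_sequence` are hypothetical stand-ins for a trained model, and the toy conditional depends only on the last symbol of the prefix.

```python
import random

# Hypothetical toy model over the alphabet {"a", "b"} (illustrative numbers only).
FIRST = {"a": 0.6, "b": 0.4}
NEXT = {"a": {"a": 0.7, "b": 0.3},
        "b": {"a": 0.2, "b": 0.8}}


def conditional(prefix):
    """Return the model's distribution over the next symbol given the prefix."""
    return FIRST if not prefix else NEXT[prefix[-1]]


def sample_sequence(length, rng):
    """Generate a sequence symbol by symbol, feeding samples back as the prefix."""
    prefix = []
    for _ in range(length):
        dist = conditional(prefix)
        symbols = list(dist)
        weights = [dist[s] for s in symbols]
        # The sampled symbol becomes part of the prefix for the next step, so
        # the model conditions on its own output rather than on real data.
        prefix.append(rng.choices(symbols, weights=weights, k=1)[0])
    return "".join(prefix)
```

For example, `sample_sequence(5, random.Random(0))` returns a five-character string over the toy alphabet. The loop makes the train/test mismatch visible: nothing here ever consults a real training prefix.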
4 Symptomatic treatment: Scheduled sampling
Bengio et al. (2015) hypothesise that the cause of the observed poor behaviour is the disconnect between how the model is trained (it's always fed prefixes from real data) and how it's used (it's always fed synthetic prefixes generated by the model itself). To address this, the authors propose an alternative training strategy called scheduled sampling (SS). In scheduled sampling, the network is sometimes given its own synthetic data as prefix instead of a real prefix at training time. This, the authors argue, simulates the environment in which the model is used when generating samples from it.
More specifically, we turn each training sequence into a modified training sequence in a recursive fashion using the following procedure:

for the $n$th symbol we draw from a Bernoulli distribution with parameter $\epsilon$ to decide whether we keep the original symbol (with probability $\epsilon$) or use one generated by the model

if we decided to replace the symbol, we use the current model RNN to output the predictive distribution of the next symbol given the current prefix, and sample from this predictive distribution

we add to the training loss the negative log predictive probability of the real $n$th symbol, given the prefix (the prefix at this point may already contain generated characters)

depending on the coinflip above, the original or simulated character is added to the prefix and we continue with the recursion
The method is called scheduled sampling to describe the way the hyperparameter $\epsilon$ is annealed during training from an initial value of $\epsilon = 1$ down to $\epsilon = 0$. Here, we would like to understand the limiting behaviour of this training procedure, and whether and why it is an appropriate way to address the shortcomings of maximum likelihood training.
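The procedure above can be sketched for a single training sequence as follows. This is a hedged illustration rather than the implementation of Bengio et al. (2015): the toy model (`FIRST`, `NEXT`, `conditional`) and the function name `scheduled_sampling_loss` are hypothetical, and `epsilon` is assumed to be the probability of keeping the real symbol in the prefix.

```python
import math
import random

# Hypothetical toy model over the alphabet {"a", "b"} (illustrative numbers only).
FIRST = {"a": 0.6, "b": 0.4}
NEXT = {"a": {"a": 0.7, "b": 0.3},
        "b": {"a": 0.2, "b": 0.8}}


def conditional(prefix):
    """Return the model's distribution over the next symbol given the prefix."""
    return FIRST if not prefix else NEXT[prefix[-1]]


def scheduled_sampling_loss(sequence, epsilon, rng):
    """Negative log-likelihood of the real symbols under mixed prefixes.

    Assumption: epsilon is the probability of keeping the real symbol.
    """
    prefix = []
    loss = 0.0
    for real_symbol in sequence:
        dist = conditional(prefix)
        # The loss always scores the real symbol, whatever the prefix contains.
        loss -= math.log(dist[real_symbol])
        if rng.random() < epsilon:
            prefix.append(real_symbol)      # keep the original symbol
        else:
            symbols = list(dist)            # replace it with a model sample
            weights = [dist[s] for s in symbols]
            prefix.append(rng.choices(symbols, weights=weights, k=1)[0])
    return loss
```

With `epsilon = 1.0` every prefix is real and the loss reduces to the ordinary negative log-likelihood; with `epsilon = 0.0` the model never sees a real symbol in its prefix.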
4.1 Scheduled sampling formulated as KL divergence minimisation
To keep notation simple, let us consider the case of learning sequences of length 2, that is, pairs of random symbols $x_1$ and $x_2$. Our aim is to formulate a closed-form training objective that corresponds to scheduled sampling.
If $x_1$ is kept original, rather than replaced by a sample, the scheduled sampling objective in fact remains the same as maximum likelihood. We can understand maximum likelihood as minimising the following KL divergence between the true data distribution $P$ and our approximation $Q$ (more precisely, maximum likelihood minimises the cross-entropy $-\mathbb{E}_{x \sim P} \log Q(x) = \mathrm{KL}[P\|Q] + \mathbb{H}[P]$, where $\mathbb{H}[P]$ is the differential entropy of the training data):
(2) $D_{ML}[P\|Q] = \mathrm{KL}[P_{x_1,x_2} \| Q_{x_1,x_2}]$
(3) $= \mathrm{KL}[P_{x_1} \| Q_{x_1}] + \mathbb{E}_{z \sim P_{x_1}} \mathrm{KL}[P_{x_2|x_1=z} \| Q_{x_2|x_1=z}]$
Here, $P_{x_1}$ and $Q_{x_1}$ denote the marginal distributions of the first symbol $x_1$ under $P$ and $Q$ respectively, while $P_{x_2|x_1=z}$ and $Q_{x_2|x_1=z}$ denote the conditional distributions of the second symbol $x_2$ conditioned on the value of the first symbol being $z$.
The other case we need to consider is when $x_1$ is replaced by a sample from the model, in this case from $Q_{x_1}$. The training objective can now be expressed as the following divergence:
(4) $D_{alternative}[P\|Q] = \mathrm{KL}[P_{x_1} \| Q_{x_1}]$
(5) $\quad + \mathbb{E}_{y \sim Q_{x_1}} \mathrm{KL}[P_{x_2} \| Q_{x_2|x_1=y}]$
Notice how in the second term the KL divergence is now measured from the marginal $P_{x_2}$ rather than the conditional $P_{x_2|x_1}$. This is because the real value of the first symbol $x_1$ is never shown to the model when it is asked to predict the second symbol $x_2$.
In scheduled sampling, we choose randomly between the above two cases, so the full SS objective can be described as a convex combination of $D_{ML}$ and $D_{alternative}$ above:
(6) $D_{SS}[P\|Q] = \epsilon \, D_{ML}[P\|Q] + (1 - \epsilon) \, D_{alternative}[P\|Q]$
It is worth noting at this point that this divergence is an idealised form of scheduled sampling. In the actual algorithm, expectations over $y$ and $z$ would be implemented by sampling. (The authors also propose taking the argmax of each distribution instead of sampling; this case is harder to analyse, but we think our general observations still hold.) This divergence describes the method's behaviour in the limit of infinite training data.
By rearranging terms we can further express the SS objective as the following KL divergence:
(7)  
(8) 
A very natural requirement for any divergence function used to assess goodness of fit in probabilistic models is that it is minimised when $Q = P$. In statistics, this property is known as strict propriety, in the sense of strictly proper scoring rules (Gneiting & Raftery, 2007). Working with strictly proper divergences guarantees consistency, i.e. that the training procedure can ultimately recover the true $P$, assuming the model class is flexible enough and enough training data is provided. What the above analysis shows us is that scheduled sampling is not a consistent estimation strategy. As $\epsilon \to 0$, the divergence is globally minimised at the factorised distribution $Q^*(x_1, x_2) = P(x_1) P(x_2)$, rather than at the correct joint distribution $P$. The model is still inconsistent when intermediate values $0 < \epsilon < 1$ are used; in this case the divergence has a global optimum that is somewhere between the true joint $P$ and the factorised distribution $P_{x_1} P_{x_2}$.
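The inconsistency can be checked numerically on a small example. The sketch below assumes the idealised objective is the convex combination described in the text, with weight $\epsilon$ on the maximum likelihood case and $1-\epsilon$ on the case where $x_1$ is replaced by a model sample; the binary toy distribution `P_JOINT` and all function names are hypothetical.

```python
import math


def kl(p, q):
    """KL[p || q] for two distributions given as dicts over the same keys."""
    return sum(p[k] * math.log(p[k] / q[k]) for k in p if p[k] > 0.0)


# Toy "true" distribution P over pairs (x1, x2) with strongly dependent symbols.
P_JOINT = {(0, 0): 0.4, (0, 1): 0.1, (1, 0): 0.1, (1, 1): 0.4}
P_X1 = {v: sum(p for (a, _), p in P_JOINT.items() if a == v) for v in (0, 1)}
P_X2 = {v: sum(p for (_, b), p in P_JOINT.items() if b == v) for v in (0, 1)}
P_X2_GIVEN = {z: {v: P_JOINT[(z, v)] / P_X1[z] for v in (0, 1)} for z in (0, 1)}


def d_ss(q_x1, q_x2_given, epsilon):
    """Idealised SS objective for a model given by Q(x1) and Q(x2 | x1)."""
    first = kl(P_X1, q_x1)
    # Case 1: x1 kept from the data (maximum likelihood).
    d_ml = first + sum(P_X1[z] * kl(P_X2_GIVEN[z], q_x2_given[z]) for z in (0, 1))
    # Case 2: x1 replaced by a model sample, so the target is the marginal of x2.
    d_alt = first + sum(q_x1[y] * kl(P_X2, q_x2_given[y]) for y in (0, 1))
    return epsilon * d_ml + (1.0 - epsilon) * d_alt


TRUE_JOINT = (P_X1, P_X2_GIVEN)            # Q equal to the true joint P
FACTORISED = (P_X1, {0: P_X2, 1: P_X2})    # Q(x1, x2) = P(x1) P(x2)
```

In this toy setting, `d_ss(*FACTORISED, 0.0)` is zero while `d_ss(*TRUE_JOINT, 0.0)` is strictly positive, so at `epsilon = 0` the objective prefers the factorised model over the true joint. At `epsilon = 1.0` the ordering is reversed, as expected for maximum likelihood.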
Based on this analysis we suggest that scheduled sampling works by pushing models towards a trivial solution of memorising the distribution of symbols conditioned on their position in the sequence, rather than on the prefix of preceding symbols. In recurrent neural network (RNN) terminology, this would mean that the optimal architecture under SS uses its hidden states merely to implement a simple counter, and learns to pay no attention whatsoever to the content of the sequence prefix. While this may indeed lead to models that are more likely to recover from mistakes, we believe it fails to address the limitations of maximum likelihood the authors initially set out to solve.
How could an inconsistent training procedure still achieve state-of-the-art performance in the image captioning challenge? There are multiple possible explanations for this. We speculate that the optimisation was not run until full convergence, and that an improvement over the maximum likelihood solution was perhaps found by coincidence, due to the interplay between early stopping, random restarts, the specific structure of the model class and the annealing schedule for $\epsilon$.
5 The Diagnosis
Having discussed scheduled sampling, a method proposed to remedy the symptoms explained in Section 3, we now seek a better explanation of why those symptoms exist in the first place. We will now leave the autoregressive model class, and consider probabilistic generative models in their full generality.
The symptoms outlined in Section 3 can be attributed to a mismatch between the loss function used for training (likelihood) and the loss used for evaluating the model (the perceptual quality of samples produced by the model). To fix this problem we need a training objective that more closely matches the perceptual metric used for evaluation, and ideally one that allows for a consistent statistical estimation framework.
5.0.1 A model of no-reference perceptual quality assessment
When researchers evaluate their generative models for perceptual quality, they draw samples from them and then, for lack of a better word, eyeball the samples. In visual information processing this is often referred to as no-reference perceptual quality assessment (see e.g. Wang et al., 2002). When using the model in an application like caption generation, we typically draw a sample $y$ from a conditional model $Q_{y|x}$, where $x$ represents the context of the query, and present it to a human observer. We would like each sample to pass a Turing test: we want the human observer to feel that $y$ is a plausible, naturally occurring response within the context of the query $x$.
In this section, we will propose that the KL divergence $\mathrm{KL}[Q\|P]$ can be used as an idealised objective function to describe the no-reference perceptual quality assessment scenario. First of all, we make the assumption that the perceived quality of each sample $x$ is related to its surprisal $-\log Q_{human}(x)$ under the human observer's subjective prior over stimuli, $Q_{human}$. We further assume that the human observer has learnt an accurate model of the natural distribution of stimuli, thus $Q_{human}(x) = P(x)$. These two assumptions suggest that in order to optimise our chances in the Turing test scenario, we need to minimise the following cross-entropy or perplexity term:
(9) $-\mathbb{E}_{x \sim Q} \log P(x)$
Note that this term mirrors the average negative log-likelihood $-\mathbb{E}_{x \sim P} \log Q(x)$, with the roles of $P$ and $Q$ exchanged.
However, the objective in Eqn. 9 would be minimised by a model $Q$ that deterministically picks the most likely stimulus. To enforce diversity one can simultaneously try to maximise the entropy of $Q$. This leaves us with the following KL divergence to optimise:
(10) $\mathrm{KL}[Q\|P] = -\mathbb{E}_{x \sim Q} \log P(x) + \mathbb{E}_{x \sim Q} \log Q(x)$
It is known that $\mathrm{KL}[Q\|P]$ is minimised when $Q = P$; therefore minimising it would correspond to a consistent estimation strategy. However, it is only well-defined when $P$ is positive and bounded on the full support of $Q$, which is not the case when $P$ is an empirical distribution of samples and $Q$ is a smooth probabilistic model. For this reason, $\mathrm{KL}[Q\|P]$ is not viable as a practical training objective in statistical estimation. Still, we can use it as our idealised perceptual quality metric to motivate our choice of practical objective functions.
5.0.2 How does this explain the symptoms?
The differences in behaviour between $\mathrm{KL}[P\|Q]$ and $\mathrm{KL}[Q\|P]$ are well understood and exploited, for example, in the context of approximate Bayesian inference (Lacoste-Julien et al., 2011; MacKay, 2003; Minka, 2001). The differences are most visible when the model is misspecified: imagine trying to model a multimodal $P$ with a simpler, unimodal model $Q$. Minimising $\mathrm{KL}[P\|Q]$ corresponds to moment matching and has a tendency to find models $Q$ that cover all the modes of $P$, at the cost of placing probability mass where $P$ has none. Minimising $\mathrm{KL}[Q\|P]$ in this case leads to mode-seeking behaviour: the optimal $Q$ will typically concentrate around the largest mode of $P$, at the cost of completely ignoring smaller modes. These differences are illustrated visually in Figure 1, panels B and D.
In the context of generative models this means that minimising $\mathrm{KL}[P\|Q]$ often leads to models that overgeneralise, and sometimes produce samples that are very unlikely under $P$. This would explain why recurrent neural networks trained via maximum likelihood also have a tendency to produce completely unseen sequences. Minimising $\mathrm{KL}[P\|Q]$ will aim to create a model that can generate all the behaviour that is observed in real data, at the cost of introducing behaviours that are never seen. By contrast, if we train a generative model by minimising $\mathrm{KL}[Q\|P]$, the model will very conservatively try to avoid any behaviour that is unlikely under $P$. This comes at the cost of ignoring modes of $P$ completely, unless those additional modes can be modelled without introducing probability mass in regions where $P$ has none.
Once again, both $\mathrm{KL}[P\|Q]$ and $\mathrm{KL}[Q\|P]$ define consistent estimation strategies. They differ in the kind of errors they make under severe model misspecification, particularly in high dimensions.
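The contrast between the two directions can be reproduced in a small one-dimensional experiment. The sketch below is an illustrative setup, not the one used for Figure 1: it discretises a hypothetical two-mode mixture `P` on a grid and searches a coarse grid of Gaussian candidates for the minimiser of each divergence. All names and numeric choices are assumptions.

```python
import math

XS = [i * 0.1 for i in range(-100, 101)]   # evaluation grid on [-10, 10]


def normalise(weights):
    total = sum(weights)
    return [w / total for w in weights]


def gaussian(mean, std):
    """Discretised, renormalised Gaussian on the grid XS."""
    return normalise([math.exp(-0.5 * ((x - mean) / std) ** 2) for x in XS])


def kl(p, q):
    """KL[p || q] for two discrete distributions on the same grid."""
    total = 0.0
    for pi, qi in zip(p, q):
        if pi > 0.0:
            if qi == 0.0:
                return math.inf
            total += pi * math.log(pi / qi)
    return total


# Bimodal "data" distribution: a larger mode at -3 and a smaller one at +3.
P = [0.6 * a + 0.4 * b for a, b in zip(gaussian(-3.0, 1.0), gaussian(3.0, 1.0))]

# Unimodal model family: (mean, std) pairs on a coarse grid.
CANDIDATES = [(m * 0.5, 0.5 + 0.25 * s) for m in range(-10, 11) for s in range(19)]


def best_fit(divergence):
    """Return the (mean, std) candidate with the smallest divergence value."""
    return min(CANDIDATES, key=lambda ms: divergence(gaussian(*ms)))


forward = best_fit(lambda q: kl(P, q))   # minimises KL[P || Q]
reverse = best_fit(lambda q: kl(q, P))   # minimises KL[Q || P]
```

In this setup `forward` is a broad Gaussian centred between the two modes, while `reverse` is a narrow Gaussian sitting on the larger mode near -3, matching the mode-covering and mode-seeking behaviour described above.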
6 Generalised Adversarial Training
We theorised that $\mathrm{KL}[Q\|P]$ may be a more meaningful training objective if our aim is to improve the perceptual quality of generative models, but it is impractical as an objective function.
Here we show that a generalised version of adversarial training (Goodfellow et al., 2014) can be used to approximate training based on $\mathrm{KL}[Q\|P]$. Adversarial training can be described as minimising an approximation to the Jensen-Shannon divergence between $P$ and $Q$ (Goodfellow et al., 2014; Theis et al., 2015). The JS divergence between $P$ and $Q$ is defined by the following formula:
(11) $\mathrm{JSD}[P\|Q] = \mathrm{JSD}[Q\|P] = \frac{1}{2} \mathrm{KL}\left[P \,\middle\|\, \frac{P+Q}{2}\right] + \frac{1}{2} \mathrm{KL}\left[Q \,\middle\|\, \frac{P+Q}{2}\right]$
Unlike KL divergence, the JS divergence is symmetric in its arguments, and can be understood as being somewhere between $\mathrm{KL}[P\|Q]$ and $\mathrm{KL}[Q\|P]$ in terms of its behaviour. One can therefore hope that JSD would behave a bit more like $\mathrm{KL}[Q\|P]$ and therefore ultimately tend to produce more realistic samples. Indeed, the behaviour of JSD minimisation under model misspecification is more similar to $\mathrm{KL}[Q\|P]$ than to $\mathrm{KL}[P\|Q]$, as illustrated in Figure 1. Empirically, methods built on adversarial training do tend to produce appealing samples (Goodfellow et al., 2014; Denton et al., 2015).
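Moreover, we can formally show that the JS divergence is indeed an interpolation between the two KL divergences in the following sense. Let us consider a more general definition of the Jensen-Shannon divergence, parametrised by a non-trivial probability $0 < \pi < 1$:
(12) $\mathrm{JS}_{\pi}[P\|Q] = \pi \cdot \mathrm{KL}[P \| \pi P + (1-\pi) Q] + (1-\pi) \cdot \mathrm{KL}[Q \| \pi P + (1-\pi) Q]$
For any given value of $\pi \neq \frac{1}{2}$ this generalised Jensen-Shannon divergence is no longer symmetric in its arguments $P$ and $Q$; instead the following weaker notion of symmetry holds:
(13) $\mathrm{JS}_{\pi}[P\|Q] = \mathrm{JS}_{1-\pi}[Q\|P]$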
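It is easy to show that the $\mathrm{JS}_{\pi}$ divergence converges to $0$ in the limit of both $\pi \to 0$ and $\pi \to 1$. Crucially, it can be shown that the gradients with respect to $\pi$ at these two extremes recover $\mathrm{KL}[P\|Q]$ as $\pi \to 0$ and, up to sign, $\mathrm{KL}[Q\|P]$ as $\pi \to 1$. A proof of this property can be obtained by considering the Taylor expansion $\mathrm{KL}[Q \| Q + a] \approx a^{T} H a$, where $H$ is the positive definite Hessian, and substituting $a = \pi (P - Q)$ as follows:
(14) $\mathrm{JS}_{\pi}[P\|Q] = \pi \, \mathrm{KL}[P \| \pi P + (1-\pi) Q] + (1-\pi) \, \mathrm{KL}[Q \| \pi P + (1-\pi) Q]$
(15) $= \pi \, \mathrm{KL}[P \| Q + \pi (P - Q)] + (1-\pi) \, \mathrm{KL}[Q \| Q + \pi (P - Q)]$
(16) $\approx \pi \, \mathrm{KL}[P\|Q] + (1-\pi) \, \pi^{2} (P - Q)^{T} H (P - Q)$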
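Therefore, we can say that for infinitesimally small values of $\pi$, $\mathrm{JS}_{\pi}[P\|Q]$ is approximately proportional to $\mathrm{KL}[P\|Q]$:
(17) $\frac{\mathrm{JS}_{\pi}[P\|Q]}{\pi} \approx \mathrm{KL}[P\|Q]$
And by the symmetry in Eqn. 13 we also have that for small values of $\pi$
(18) $\frac{\mathrm{JS}_{1-\pi}[P\|Q]}{\pi} \approx \mathrm{KL}[Q\|P]$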
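These limits are easy to check numerically for discrete distributions. The following sketch assumes the generalised divergence is the $\pi$-weighted sum of KL divergences to the mixture $\pi P + (1-\pi) Q$ described in the text; the function name `js_pi` and the two three-outcome distributions are hypothetical.

```python
import math


def kl(p, q):
    """KL[p || q] for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)


def js_pi(p, q, pi):
    """Generalised Jensen-Shannon divergence with mixing weight 0 < pi < 1."""
    mix = [pi * a + (1.0 - pi) * b for a, b in zip(p, q)]
    return pi * kl(p, mix) + (1.0 - pi) * kl(q, mix)


# Two arbitrary distributions over three outcomes (illustrative numbers only).
P = [0.7, 0.2, 0.1]
Q = [0.2, 0.3, 0.5]

small = 1e-4
ratio_near_zero = js_pi(P, Q, small) / small        # close to KL[P || Q]
ratio_near_one = js_pi(P, Q, 1.0 - small) / small   # close to KL[Q || P]
```

Comparing `ratio_near_zero` with `kl(P, Q)` and `ratio_near_one` with `kl(Q, P)` shows agreement to within a term of order `small`, and `js_pi(P, Q, 0.3)` matches `js_pi(Q, P, 0.7)`, as the weaker symmetry property requires.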