Computing the Optimal Discriminator between Mixtures of Gaussians
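Let $p$ and $q$ be two mixtures of $k$ univariate Gaussians as follows:

$$p(x)=\frac{1}{k}\sum_{i=1}^{k}\mathcal{N}\big(x;\mu_i^{p},1\big),\qquad q(x)=\frac{1}{k}\sum_{i=1}^{k}\mathcal{N}\big(x;\mu_i^{q},1\big).$$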
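The following is a minimal sketch of these densities for general $k$. It assumes equal weights $1/k$ and unit variance, which are the same assumptions the code later in this post makes for $k=2$. `mixture_pdf` is a hypothetical helper name, and the sketch only checks that each mixture integrates to one.

```python
import numpy as np
from scipy.stats import norm


def mixture_pdf(x, mus):
    """Equal-weight mixture of unit-variance Gaussians with means `mus` (assumed form)."""
    x = np.asarray(x, dtype=float)
    # (1/k) * sum_i N(x; mu_i, 1), computed as the mean over the k component densities
    return np.mean([norm.pdf(x, loc=mu) for mu in mus], axis=0)


p_mus = [-2.0, 2.0]
q_mus = [1.0, 5.0]
xs = np.linspace(-10.0, 15.0, 20001)
dx = xs[1] - xs[0]
p_x = mixture_pdf(xs, p_mus)
q_x = mixture_pdf(xs, q_mus)

# Riemann sums: both densities should integrate to (approximately) one
mass_p = np.sum(p_x) * dx
mass_q = np.sum(q_x) * dx
print(mass_p, mass_q)
```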
We are interested in a discriminator $D$ that distinguishes between the above distributions. Let’s assess $D$’s performance with the WGAN objective

$$L(D)=\mathbb{E}_{x\sim p}\big[D(x)\big]-\mathbb{E}_{x\sim q}\big[D(x)\big].$$
That is, the greater $L$ is, the better our discriminator $D$ is. Let the output range of $D$ be $[0,1]$. $L(D)$ can be written as

$$L(D)=\int_{-\infty}^{\infty} D(x)\,\big(p(x)-q(x)\big)\,dx.$$
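This sketch makes the two forms of $L(D)$ concrete. It estimates $\mathbb{E}_{x\sim p}[D(x)]-\mathbb{E}_{x\sim q}[D(x)]$ by sampling and compares that estimate with $\int D(x)\,(p(x)-q(x))\,dx$ evaluated on a grid. The sigmoid discriminator `D` and the sampler `sample_mixture` are hypothetical choices for illustration. The mixtures are assumed to have equal weights and unit variance.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)


def mixture_pdf(x, mus):
    # Equal-weight mixture of unit-variance Gaussians (assumed form)
    return np.mean([norm.pdf(x, loc=mu) for mu in mus], axis=0)


def sample_mixture(mus, n):
    # Pick a component uniformly at random, then add standard normal noise
    comps = rng.integers(len(mus), size=n)
    return np.asarray(mus, dtype=float)[comps] + rng.standard_normal(n)


def D(x):
    # An arbitrary (non-optimal) discriminator with outputs in [0, 1]
    return 1.0 / (1.0 + np.exp(x))


p_mus, q_mus = [-2.0, 2.0], [1.0, 5.0]

# Form 1: difference of expectations, estimated from samples
n = 200_000
L_mc = D(sample_mixture(p_mus, n)).mean() - D(sample_mixture(q_mus, n)).mean()

# Form 2: integral of D(x) * (p(x) - q(x)), approximated by a Riemann sum
xs = np.linspace(-12.0, 17.0, 50_001)
dx = xs[1] - xs[0]
L_int = np.sum(D(xs) * (mixture_pdf(xs, p_mus) - mixture_pdf(xs, q_mus))) * dx

# The two estimates should agree up to Monte Carlo error
print(L_mc, L_int)
```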
Since $D(x)\in[0,1]$, each point $x$ contributes $D(x)\,(p(x)-q(x))$ to $L(D)$. This contribution is largest with $D(x)=1$ where $p(x)\geq q(x)$ and with $D(x)=0$ where $p(x)<q(x)$, so these choices maximize $L(D)$. With this observation at hand, one can think of the discriminator as an indicator function, with the optimal one being $D^*(x)=\mathbb{1}\{p(x)\geq q(x)\}$.
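The pointwise argument can be checked numerically. This sketch assumes equal-weight, unit-variance mixtures and evaluates a grid approximation of $L(D)=\int D(x)\,(p(x)-q(x))\,dx$ for two kinds of discriminator:

- the indicator $D^*$;
- 100 random discriminators that take arbitrary values in $[0,1]$ at each grid point.

None of the random discriminators should beat $D^*$. The grid and the random discriminators are illustrative choices.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(1)

p_mus, q_mus = [-2.0, 2.0], [1.0, 5.0]
xs = np.linspace(-12.0, 17.0, 20_001)
dx = xs[1] - xs[0]


def mixture_pdf(x, mus):
    # Equal-weight mixture of unit-variance Gaussians (assumed form)
    return np.mean([norm.pdf(x, loc=mu) for mu in mus], axis=0)


diff = mixture_pdf(xs, p_mus) - mixture_pdf(xs, q_mus)


def objective(d_values):
    # Riemann-sum approximation of the integral of D(x) * (p(x) - q(x))
    return np.sum(d_values * diff) * dx


# Optimal discriminator: indicator of p(x) >= q(x)
L_star = objective((diff >= 0).astype(float))

# Random discriminators: arbitrary values in [0, 1] at every grid point
L_random = [objective(rng.uniform(0.0, 1.0, size=xs.size)) for _ in range(100)]

print(L_star, max(L_random))
```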
For univariate Gaussians, $D^*(x)=1$ on intervals where $p(x)\geq q(x)$ and $0$ otherwise. To find these intervals, one can look for the zero-crossings of the function $p(x)-q(x)$. From Theorem A.2, the linear combination $p(x)-q(x)$ of $2k$ Gaussians has at most $2k-1$ zero-crossings.
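Here is a quick numerical sanity check of this bound:

1. Draw random means for two equal-weight, unit-variance mixtures of $k$ Gaussians.
2. Count the sign changes of $p(x)-q(x)$ on a dense grid.
3. Confirm that the count never exceeds $2k-1$.

A grid can miss crossings that are closer together than its spacing, so this illustrates the bound rather than proving it. The helper names are hypothetical.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(2)


def mixture_pdf(x, mus):
    # Equal-weight mixture of unit-variance Gaussians (assumed form)
    return np.mean([norm.pdf(x, loc=mu) for mu in mus], axis=0)


def count_sign_changes(p_mus, q_mus, n_grid=20_001, margin=6.0):
    # Scan a grid that extends `margin` standard deviations beyond the extreme means
    mus = np.concatenate([p_mus, q_mus])
    xs = np.linspace(mus.min() - margin, mus.max() + margin, n_grid)
    y = mixture_pdf(xs, p_mus) - mixture_pdf(xs, q_mus)
    # Count the grid steps where the predicate p(x) >= q(x) flips
    return int(np.count_nonzero(np.diff(y >= 0)))


k = 2
counts = [
    count_sign_changes(rng.uniform(-5, 5, size=k), rng.uniform(-5, 5, size=k))
    for _ in range(200)
]
print(max(counts), "<=", 2 * k - 1)
```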
With $k=2$, there are at most 3 zero-crossings, which we can find with a root-finding solver given good initial guesses. Because the sign of $p(x)-q(x)$ alternates between consecutive crossings, $p(x)\geq q(x)$ over at most $k=2$ disjoint intervals. What’s left is computing the lower and upper bounds of these intervals for an optimal discriminator, which we demonstrate in the following snippets.
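The interval bookkeeping does not have to be written case by case. This sketch works for any number of crossings. It uses a hypothetical helper, `intervals_from_roots`, which takes two inputs:

- the sorted roots of $p(x)-q(x)$;
- whether $p(x)\geq q(x)$ to the left of the first root.

It returns the intervals where $D^*(x)=1$, relying only on the sign alternating at each crossing.

```python
import math


def intervals_from_roots(roots, starts_positive):
    """Return the list of (lo, hi) intervals on which p(x) >= q(x).

    roots: sorted zero-crossings of p(x) - q(x).
    starts_positive: True if p(x) >= q(x) to the left of the first root.
    """
    # The roots split the real line into len(roots) + 1 regions
    edges = [-math.inf, *roots, math.inf]
    intervals = []
    for i in range(len(edges) - 1):
        # The sign flips at every crossing, so region i is positive when
        # i is even (if we start positive) or odd (if we start negative)
        if (i % 2 == 0) == starts_positive:
            intervals.append((edges[i], edges[i + 1]))
    return intervals


# Three crossings, starting positive: D* = 1 on (-inf, r0] and [r1, r2]
print(intervals_from_roots([-0.45, 1.5, 3.45], True))
```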
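```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import root
from scipy.stats import norm


def f(_x, *args):
    """
    Evaluate p(x) - q(x) for two equal-weight mixtures of two Gaussians.
    It is assumed that sigma = 1 for all the Gaussians.
    """
    _params = args[0]
    p_x = 0.5 * (norm.pdf(_x, loc=_params['p_mu_1']) + norm.pdf(_x, loc=_params['p_mu_2']))
    q_x = 0.5 * (norm.pdf(_x, loc=_params['q_mu_1']) + norm.pdf(_x, loc=_params['q_mu_2']))
    return p_x - q_x


def solve_fx(_f, x0, _params):
    """
    Find a zero crossing of _f starting from the initial guess x0
    """
    _res = root(_f, x0, args=(_params,))
    return float(_res.x[0])


def get_optimal_bounds(params, margin=5.0, n_grid=1000):
    """
    Compute the bounds (l_1, r_1, l_2, r_2) of the (at most) two intervals
    on which p(x) >= q(x), i.e. where D^*(x) = 1
    """
    # Scan a grid reaching `margin` standard deviations past the extreme means,
    # since zero crossings can lie outside the range of the means
    mus = list(params.values())
    x = np.linspace(min(mus) - margin, max(mus) + margin, n_grid)
    y = f(x, params)
    # Grid indices i where the sign of p(x) - q(x) flips between x[i] and x[i + 1]
    crosses = np.where(np.diff(y >= 0))[0]
    # Refine the grid crossings with a root-finder, using them as initial guesses
    x0s = x[crosses]
    vals = sorted(float(v) for v in root(f, x0s, args=(params,)).x) if len(x0s) else []
    # True if p(x) >= q(x) to the left of the first crossing
    starts_positive = len(crosses) > 0 and y[crosses[0]] >= 0

    if len(vals) == 0:
        # no sign change found (e.g. p == q): any value is fine
        l_1, r_1, l_2, r_2 = -10., 0., 0., 10.
    elif len(vals) == 1:
        if starts_positive:
            # p >= q on (-inf, v0]; the second interval is empty
            l_1, r_1, l_2, r_2 = -np.inf, vals[0], vals[0], vals[0]
        else:
            # p >= q on [v0, inf); the first interval is empty
            l_1, r_1, l_2, r_2 = vals[0], vals[0], vals[0], np.inf
    elif len(vals) == 2:
        if starts_positive:
            # p >= q on (-inf, v0] and [v1, inf)
            l_1, r_1, l_2, r_2 = -np.inf, vals[0], vals[1], np.inf
        else:
            # p >= q on [v0, v1]; split it into two adjacent halves
            mid = 0.5 * (vals[0] + vals[1])
            l_1, r_1, l_2, r_2 = vals[0], mid, mid, vals[1]
    elif len(vals) == 3:
        if starts_positive:
            # p >= q on (-inf, v0] and [v1, v2]
            l_1, r_1, l_2, r_2 = -np.inf, vals[0], vals[1], vals[2]
        else:
            # p >= q on [v0, v1] and [v2, inf)
            l_1, r_1, l_2, r_2 = vals[0], vals[1], vals[2], np.inf
    else:
        raise ValueError("There should be at most 3 crossings!")
    return l_1, r_1, l_2, r_2
```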
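The snippets below call `plot_bounds` and `plot_opt_disc`, whose definitions are not shown here. What follows is one possible sketch, not the author's code. It assumes two signatures:

- `plot_bounds(bounds, params)` takes `[l_1, r_1, l_2, r_2]`;
- `plot_opt_disc(params)` plots the intervals of $D^*$.

Both functions draw $p$ and $q$ and shade the intervals where $D(x)=1$. To stay self-contained, the sketch redefines the mixtures and finds the crossings with `scipy.optimize.brentq` on grid brackets instead of `root`. The helpers `mixture`, `p_minus_q`, `optimal_intervals` and `_plot` are hypothetical.

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import brentq
from scipy.stats import norm


def mixture(x, mu_1, mu_2):
    # Equal-weight mixture of two unit-variance Gaussians
    return 0.5 * (norm.pdf(x, loc=mu_1) + norm.pdf(x, loc=mu_2))


def p_minus_q(x, params):
    return (mixture(x, params['p_mu_1'], params['p_mu_2'])
            - mixture(x, params['q_mu_1'], params['q_mu_2']))


def optimal_intervals(params, margin=5.0, n_grid=2000):
    """Hypothetical helper: intervals on which p(x) >= q(x), i.e. D^*(x) = 1."""
    mus = list(params.values())
    xs = np.linspace(min(mus) - margin, max(mus) + margin, n_grid)
    nonneg = p_minus_q(xs, params) >= 0
    idx = np.where(np.diff(nonneg))[0]
    # Each grid sign change brackets one root, which brentq refines
    roots = [brentq(p_minus_q, xs[i], xs[i + 1], args=(params,)) for i in idx]
    edges = [-np.inf, *roots, np.inf]
    # The sign alternates across crossings, starting from the sign at xs[0]
    return [(edges[j], edges[j + 1]) for j in range(len(edges) - 1)
            if (j % 2 == 0) == bool(nonneg[0])]


def _plot(intervals, params, title):
    mus = list(params.values())
    xs = np.linspace(min(mus) - 4.0, max(mus) + 4.0, 1000)
    fig, ax = plt.subplots(figsize=(8, 3))
    ax.plot(xs, mixture(xs, params['p_mu_1'], params['p_mu_2']), label='p(x)')
    ax.plot(xs, mixture(xs, params['q_mu_1'], params['q_mu_2']), label='q(x)')
    for lo, hi in intervals:
        # Shade where D(x) = 1, clipping infinite bounds to the plotted range
        lo_c, hi_c = max(lo, xs[0]), min(hi, xs[-1])
        if lo_c < hi_c:
            ax.axvspan(lo_c, hi_c, color='green', alpha=0.2)
    ax.set_title(title)
    ax.legend()
    return ax


def plot_bounds(bounds, params):
    # bounds = [l_1, r_1, l_2, r_2], the same layout get_optimal_bounds returns
    l_1, r_1, l_2, r_2 = bounds
    return _plot([(l_1, r_1), (l_2, r_2)], params, 'D(x) = 1 on the shaded intervals')


def plot_opt_disc(params):
    return _plot(optimal_intervals(params), params, 'Optimal discriminator D*(x)')
```

In a script, call `plt.show()` afterwards to display the figures.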
Let’s visualize an arbitrary set of intervals. Note that here we have two mixtures of 2 Gaussians, so there are at most 3 zero crossings.
```python
# viz intervals for D(x) = 1{-3 <= x <= -1} | 1{1 <= x <= 3}
plot_bounds([-3, -1, 1, 3],
            {'p_mu_1': -2, 'p_mu_2': 2, 'q_mu_1': 1, 'q_mu_2': 5})
```
These bounds cover regions where $p(x) < q(x)$; for instance, $p(1) < q(1)$ even though $D(1)=1$. This should not happen for an optimal discriminator $D^*(x)$, as shown below.
```python
# viz opt disc bounds D^*(x)
params = {'p_mu_1': -2, 'p_mu_2': 2, 'q_mu_1': 1, 'q_mu_2': 5}
plot_opt_disc(params)
```
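The gap between the two discriminators can also be quantified. When $D$ is an indicator of intervals, $L(D)$ has a closed form: each interval $[l, r]$ contributes $P_p([l,r]) - P_q([l,r])$, which is a difference of Gaussian CDFs. By the pointwise argument earlier, the optimal value is $\int \max(p(x)-q(x),0)\,dx$.

This sketch compares the arbitrary bounds with that optimum. It assumes equal-weight, unit-variance mixtures. `mixture_cdf`, `mixture_pdf` and `interval_objective` are hypothetical helper names.

```python
import numpy as np
from scipy.stats import norm

params = {'p_mu_1': -2, 'p_mu_2': 2, 'q_mu_1': 1, 'q_mu_2': 5}


def mixture_cdf(x, mu_1, mu_2):
    # CDF of an equal-weight mixture of two unit-variance Gaussians
    return 0.5 * (norm.cdf(x, loc=mu_1) + norm.cdf(x, loc=mu_2))


def mixture_pdf(x, mu_1, mu_2):
    return 0.5 * (norm.pdf(x, loc=mu_1) + norm.pdf(x, loc=mu_2))


def interval_objective(bounds, params):
    """L(D) for D(x) = 1 on [l_1, r_1] and [l_2, r_2] (assumed non-overlapping), else 0."""
    l_1, r_1, l_2, r_2 = bounds
    total = 0.0
    for lo, hi in [(l_1, r_1), (l_2, r_2)]:
        # Mass of each mixture on [lo, hi]; norm.cdf handles infinite bounds
        p_mass = (mixture_cdf(hi, params['p_mu_1'], params['p_mu_2'])
                  - mixture_cdf(lo, params['p_mu_1'], params['p_mu_2']))
        q_mass = (mixture_cdf(hi, params['q_mu_1'], params['q_mu_2'])
                  - mixture_cdf(lo, params['q_mu_1'], params['q_mu_2']))
        total += p_mass - q_mass
    return total


# Objective of the arbitrary discriminator 1{-3 <= x <= -1} | 1{1 <= x <= 3}
L_arbitrary = interval_objective([-3, -1, 1, 3], params)

# Optimal value: integral of max(p(x) - q(x), 0), approximated by a Riemann sum
xs = np.linspace(-12.0, 15.0, 50_001)
diff = (mixture_pdf(xs, params['p_mu_1'], params['p_mu_2'])
        - mixture_pdf(xs, params['q_mu_1'], params['q_mu_2']))
L_star = np.sum(np.maximum(diff, 0.0)) * (xs[1] - xs[0])

print(L_arbitrary, L_star)
```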