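Let $p$ and $q$ be two mixtures of $k$ univariate Gaussians as follows:

$$p(x)=\sum_{i=1}^{k} w^{p}_{i}\,\mathcal{N}\!\left(x;\mu^{p}_{i},(\sigma^{p}_{i})^{2}\right),\qquad q(x)=\sum_{i=1}^{k} w^{q}_{i}\,\mathcal{N}\!\left(x;\mu^{q}_{i},(\sigma^{q}_{i})^{2}\right),$$

where the weights of each mixture are non-negative and sum to one.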

We are interested in a discriminator $D$ that distinguishes between the two distributions. Let’s assess $D$’s performance with the WGAN objective

$$L(D)=\mathbb{E}_{x\sim p}[D(x)]-\mathbb{E}_{x\sim q}[D(x)].$$

That is, the greater $L$ is, the better our discriminator $D$ is. Let the output range of $D$ be $[0,1]$. $L(D)$ can be written as

$$L(D)=\int \big(p(x)-q(x)\big)\,D(x)\,dx.$$

Since $0\leq D(x)\leq 1$, $L(D)$ is maximized by choosing $D(x)=0$ wherever $p(x)<q(x)$ and $D(x)=1$ wherever $p(x)\geq q(x)$. With this observation at hand, one can think of the discriminator as an indicator function, with the optimal one being $D^*(x)=\mathbb{1}\{p(x)\geq q(x)\}$.
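The optimality argument above can be checked numerically. The following is a minimal sketch, not the post's own code: it assumes equal mixture weights and unit variances, uses the means from the example later in this post, and approximates $L(D)$ with a midpoint rule on a finite window. The helper names (`mixture_pdf`, `wgan_objective`) and the window `[-12, 15]` are illustrative choices.

```python
import math

def normal_pdf(x, mu, sigma=1.0):
    """Density of N(mu, sigma^2) at x."""
    z = (x - mu) / sigma
    return math.exp(-0.5 * z * z) / (sigma * math.sqrt(2.0 * math.pi))

def mixture_pdf(x, mus):
    """Equal-weight mixture of unit-variance Gaussians (an assumption)."""
    return sum(normal_pdf(x, mu) for mu in mus) / len(mus)

def wgan_objective(disc, p_mus, q_mus, lo=-12.0, hi=15.0, n=20000):
    """Midpoint-rule approximation of L(D) = integral of (p - q) * D."""
    dx = (hi - lo) / n
    total = 0.0
    for i in range(n):
        x = lo + (i + 0.5) * dx
        total += (mixture_pdf(x, p_mus) - mixture_pdf(x, q_mus)) * disc(x)
    return total * dx

p_mus, q_mus = (-2.0, 2.0), (1.0, 5.0)

def optimal(x):
    # indicator of p(x) >= q(x)
    return 1.0 if mixture_pdf(x, p_mus) >= mixture_pdf(x, q_mus) else 0.0

def constant(x):
    # D = 1 everywhere: both densities integrate to 1, so L(D) is 0
    return 1.0

def soft(x):
    # a smooth, decreasing discriminator with values in (0, 1)
    return 1.0 / (1.0 + math.exp(x - 1.5))

L_opt = wgan_objective(optimal, p_mus, q_mus)
L_const = wgan_objective(constant, p_mus, q_mus)
L_soft = wgan_objective(soft, p_mus, q_mus)
print(L_opt, L_soft, L_const)
```

Any other $[0,1]$-valued discriminator plugged into `wgan_objective` should score no higher than `optimal`, up to the discretization error of the grid.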

For univariate Gaussians, $D^*(x)=1$ on intervals where $p(x)\geq q(x)$ and $0$ otherwise. To find these intervals, one can look for the zero-crossings of the function $p(x)-q(x)$. From Theorem A.2, the linear combination $p(x)-q(x)$ has at most $2k-1$ zero-crossings.

With $k=2$, we have at most 3 zero-crossings, which we can find with a root-finding solver given good initial guesses. Consequently, $p(x)\geq q(x)$ over at most $k=2$ disjoint intervals: the crossings split the real line into at most $2k=4$ regions on which the sign of $p(x)-q(x)$ alternates. What’s left now is computing the lower and upper bounds of these intervals for an optimal discriminator. We will demonstrate this in the following snippets.

import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import root
from scipy.stats import norm

def f(_x, *args):
    """
    Evaluate p(x) - q(x).
    It is assumed that sigma = 1 for all the Gaussians.
    """
    # the parameter dict arrives as the single extra positional argument
    _params = args[0]
    p_x = 0.5 * (norm.pdf(_x, loc=_params['p_mu_1']) + norm.pdf(_x, loc=_params['p_mu_2']))
    q_x = 0.5 * (norm.pdf(_x, loc=_params['q_mu_1']) + norm.pdf(_x, loc=_params['q_mu_2']))
    return p_x - q_x

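def solve_fx(_f, x0, _params):
    """
    Find a zero crossing of _f, starting from the initial guess x0
    """
    _res = root(_f, x0, args=(_params,))
    # _res.x is a length-1 array when x0 is a scalar
    return float(_res.x[0])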

def get_optimal_bounds(params):
    """
    Compute the bounds (l_1, r_1, l_2, r_2) of the two intervals
    on which the optimal discriminator outputs 1
    """
    sorted_param_names = sorted(params, key=params.get)

    # look for sign changes between the smallest and the largest mean
    x = np.linspace(params[sorted_param_names[0]], params[sorted_param_names[-1]], 1000)
    y = f(x, params)

    # skip exact zeros, then map the sign changes back to indices of x
    nonzero = np.flatnonzero(y)
    crosses = nonzero[np.where(np.diff(np.sign(y[nonzero])))[0]]

    # refine the grid estimates with the solver (nothing to solve without crossings)
    vals = sorted(root(f, x[crosses], args=(params,)).x) if len(crosses) else []

    if len(vals) == 1:
        if y[crosses[0]] > 0:
            l_1 = -np.inf
            r_1 = vals[0]
            l_2 = r_1
            r_2 = l_2
        else:
            l_1 = vals[0]
            r_1 = vals[0]
            l_2 = vals[0]
            r_2 = np.inf
    elif len(vals) == 2:
        if y[crosses[0]] > 0:
            l_1 = -np.inf
            r_1 = vals[0]
            l_2 = vals[1]
            r_2 = np.inf
        else:
            l_1 = vals[0]
            r_1 = 0.5 * (vals[0] + vals[1])
            l_2 = r_1
            r_2 = vals[1]
    elif len(vals) == 3:
        if y[crosses[0]] > 0:
            l_1 = -np.inf
            r_1 = vals[0]
            l_2 = vals[1]
            r_2 = vals[2]
        else:
            l_1 = vals[0]
            r_1 = vals[1]
            l_2 = vals[2]
            r_2 = np.inf
    elif len(vals) == 0:
        # any arbitrary value
        l_1, r_1, l_2, r_2 = -10., 0., 0., 10.
    else:
        raise Exception("There should be at most 3 crossings!")

    return l_1, r_1, l_2, r_2
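As a dependency-free cross-check of the solver-based routine, the same idea (scan for sign changes of $p(x)-q(x)$, refine each one, then keep the regions where the difference is positive) can be sketched with bisection instead of `scipy.optimize.root`. This is an illustrative sketch under the same assumptions (equal weights, unit variances); the helper names, the search window `[-6, 9]`, and the grid size are hypothetical choices, not part of the original code.

```python
import math

def normal_pdf(x, mu):
    """Density of N(mu, 1) at x."""
    return math.exp(-0.5 * (x - mu) ** 2) / math.sqrt(2.0 * math.pi)

def diff(x, p_mus, q_mus):
    """p(x) - q(x) for equal-weight, unit-variance mixtures."""
    p = sum(normal_pdf(x, mu) for mu in p_mus) / len(p_mus)
    q = sum(normal_pdf(x, mu) for mu in q_mus) / len(q_mus)
    return p - q

def zero_crossings(func, lo, hi, n=2000, tol=1e-10):
    """Scan [lo, hi] on a grid and refine every sign change by bisection."""
    xs = [lo + (hi - lo) * i / n for i in range(n + 1)]
    roots = []
    for a, b in zip(xs[:-1], xs[1:]):
        fa, fb = func(a), func(b)
        if fa == 0.0:
            # the grid point itself is a root
            roots.append(a)
        elif fa * fb < 0.0:
            while b - a > tol:
                m = 0.5 * (a + b)
                fm = func(m)
                if fa * fm <= 0.0:
                    b = m
                else:
                    a, fa = m, fm
            roots.append(0.5 * (a + b))
    return roots

def positive_intervals(func, roots):
    """Keep the regions between consecutive roots where func is positive."""
    edges = [-math.inf] + sorted(roots) + [math.inf]
    intervals = []
    for left, right in zip(edges[:-1], edges[1:]):
        if math.isinf(left) and math.isinf(right):
            probe = 0.0
        elif math.isinf(left):
            probe = right - 1.0
        elif math.isinf(right):
            probe = left + 1.0
        else:
            probe = 0.5 * (left + right)
        if func(probe) > 0.0:
            intervals.append((left, right))
    return intervals

p_mus, q_mus = (-2.0, 2.0), (1.0, 5.0)

def func(x):
    return diff(x, p_mus, q_mus)

roots = zero_crossings(func, -6.0, 9.0)
intervals = positive_intervals(func, roots)
print(roots)
print(intervals)
```

Unlike the fixed four-value return of `get_optimal_bounds`, this sketch returns a list of intervals, which avoids the degenerate placeholder intervals when there are fewer than three crossings. It will miss crossings that lie outside the chosen window or that fall closer together than the grid spacing.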


Let’s visualize an arbitrary set of intervals. Note that here we have two mixtures of 2 Gaussians each, so there exist at most 3 zero-crossings.

# viz intervals for D(x) =  1{-3<=x<= -1} | 1{1<=x<=3}
plot_bounds([-3, -1, 1, 3],
            {'p_mu_1': -2, 'p_mu_2': 2, 'q_mu_1': 1, 'q_mu_2': 5})

The bounds cover regions where $p(x) < q(x)$; this should not be the case for an optimal discriminator $D^*(x)$, as shown below.

# viz opt disc bounds D^*(x)
params = {'p_mu_1': -2, 'p_mu_2': 2, 'q_mu_1': 1, 'q_mu_2': 5}
plot_opt_disc(params) 
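The effect of covering regions where $p(x) < q(x)$ can also be quantified without plotting. For a discriminator that is an indicator of a union of disjoint intervals, $L(D)$ reduces to the probability mass $p$ assigns to those intervals minus the mass $q$ assigns to them, which has a closed form through the normal CDF. The sketch below assumes equal weights and unit variances, and its helper names are illustrative. It relies on the observation that, for these particular means, $q$ is the mirror image of $p$ about $x=1.5$, so $p(x)-q(x)$ changes sign there and $p(x)<q(x)$ on $[1, 1.5)$.

```python
import math

def normal_cdf(x, mu, sigma=1.0):
    """CDF of N(mu, sigma^2) at x; also works for x = +/- infinity."""
    return 0.5 * (1.0 + math.erf((x - mu) / (sigma * math.sqrt(2.0))))

def mixture_mass(a, b, mus):
    """Mass that an equal-weight, unit-variance mixture assigns to [a, b]."""
    return sum(normal_cdf(b, mu) - normal_cdf(a, mu) for mu in mus) / len(mus)

def interval_objective(intervals, p_mus, q_mus):
    """L(D) when D is the indicator of a union of disjoint intervals."""
    return sum(mixture_mass(a, b, p_mus) - mixture_mass(a, b, q_mus)
               for a, b in intervals)

p_mus, q_mus = (-2.0, 2.0), (1.0, 5.0)

# the arbitrary bounds from the first plot
L_arbitrary = interval_objective([(-3.0, -1.0), (1.0, 3.0)], p_mus, q_mus)
# drop [1, 1.5), where p(x) < q(x)
L_trimmed = interval_objective([(-3.0, -1.0), (1.5, 3.0)], p_mus, q_mus)
# additionally extend the first interval into the left tail, where p(x) > q(x)
L_extended = interval_objective([(-math.inf, -1.0), (1.5, 3.0)], p_mus, q_mus)

print(L_arbitrary, L_trimmed, L_extended)
```

Each adjustment moves the bounds toward the set where $p(x)\geq q(x)$, and the objective should increase at every step.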