DeepMind has recently released a paper about population-based training of neural nets. The paper proposes a simple asynchronous optimization algorithm that jointly optimizes a population of models and their hyperparameters. In this post, I aim to demonstrate a simple implementation of the algorithm by reproducing the paper’s Figure 2.



The objective is to maximize the function $Q$, defined as

import numpy as np

def q(theta):
    # the true objective Q: maximized at theta = (0, 0), where Q = 1.2
    return 1.2 - np.sum(np.array(theta)**2)
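Before going further, note where the optimum lies: $Q$ attains its maximum of $1.2$ at $\theta = (0, 0)$. A quick check (repeating the definition so the snippet runs standalone):

```python
import numpy as np

def q(theta):
    return 1.2 - np.sum(np.array(theta)**2)

# the maximum of 1.2 is attained at theta = (0, 0)
assert np.isclose(q([0., 0.]), 1.2)
# the population members below start at theta = (0.9, 0.9), well below the optimum
assert q([0.9, 0.9]) < q([0., 0.])
```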

Think of the above function as the validation set performance, whose closed-form expression is unknown. However, we can still query (evaluate) its value at specific $\theta$. While our goal is to maximize it and obtain a model with good generalization capabilities, we optimize our neural model (tune its parameters) with respect to the training set performance. That is, our model is optimized w.r.t. a surrogate objective function $\hat{Q}$ rather than the objective function we are after.

In an ideal setup, one would like $\hat{Q}=Q$ (or at least that optimizing $\hat{Q}$ corresponds to optimizing $Q$ as well), but this is not the case in general. The quality of this approximation (or coupled optimization) often depends on the model at hand and how we train it: the hyperparameters. In other words, when we optimize $\hat{Q}$ w.r.t. $\theta$, we do so given (conditioned on) a set of hyperparameters, $\hat{Q}(\theta \mid h)$.

In the considered toy example, let the empirical loss / training set performance be parameterized by $h$ as follows.

def q_hat(theta, h):
    # the surrogate objective; the `step` method below performs
    # gradient updates w.r.t. this function
    return 1.2 - np.sum(np.dot(np.array(h), np.array(theta)**2))
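Note that $\partial \hat{Q} / \partial \theta_i = -2 h_i \theta_i$, so the surrogate matches the true objective exactly only when $h = (1, 1)$; with $h = (0, 1)$, for instance, gradient steps never move $\theta_0$. A standalone check (repeating both definitions so the snippet runs on its own):

```python
import numpy as np

def q(theta):
    return 1.2 - np.sum(np.array(theta)**2)

def q_hat(theta, h):
    return 1.2 - np.sum(np.dot(np.array(h), np.array(theta)**2))

theta = [0.9, 0.9]
# with h = (1, 1) the surrogate coincides with the true objective
assert np.isclose(q_hat(theta, [1., 1.]), q(theta))
# with h = (0, 1) the surrogate ignores theta[0] entirely
assert np.isclose(q_hat(theta, [0., 1.]), 1.2 - 0.9**2)
```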

The task now boils down to using $\hat{Q}$ to tune $\theta$ while searching for $h$ to make the tuning process as efficient as possible at (indirectly) optimizing $Q$. PBT is one solution for doing so. Here is the pseudocode, and below is an implementation of the same.



Let’s start with modeling the population members:

class Member(object):
    def __init__(self, theta=(0., 0.), h=(0., 0.), _id=1, _eta=0.01, _sigma=1e-1):
        # copy into fresh lists so members never share mutable parameter state
        self.theta = list(theta)
        self.h = list(h)
        self.id = _id
        self.num_steps = 0
        self.p = q(self.theta)
        self._eta = _eta
        self._sigma = _sigma
        ## for visualization
        self.trace = []
        self.ps = [self.p]

    def eval(self):
        self.p = q(self.theta)

    def ready(self):
        return self.num_steps > 10

    def step(self):
        # gradient ascent on q_hat: d q_hat / d theta_i = -2 * h_i * theta_i
        for i in range(2):
            self.theta[i] += self._eta * (-2. * self.h[i] * self.theta[i])
        self.num_steps += 1

    def explore(self):
        for i in range(2):
            self.h[i] = np.clip(self.h[i] + self._sigma * np.random.randn(), 0, 1)

    def exploit(self, other):
        # truncation-style selection: copy the better member's weights
        if self.p <= other.p:
            self.theta = list(other.theta)
            self.p = other.p
            self.num_steps = 0
            return True
        else:
            return False

    def log(self):
        self.trace.append(list(self.theta))
        self.ps.append(self.p)

    def __str__(self):
        return 'theta:' + '%.2f' % self.theta[0] + ',%.2f' % self.theta[1] + \
               '| h:' + '%.2f' % self.h[0] + ',%.2f' % self.h[1] + '| id' + str(self.id) + \
               '| p:' + '%.2f' % self.p + '| steps:' + str(self.num_steps)
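To see that `step` actually climbs $Q$ when the hyperparameters are right, here is the same update rule in isolation, with `eta` matching the default `_eta` above and $h = (1, 1)$ (where $\hat{Q} = Q$):

```python
import numpy as np

def q(theta):
    return 1.2 - np.sum(np.array(theta)**2)

# the update in `step`: theta_i += eta * d q_hat / d theta_i
#                                = eta * (-2 * h_i * theta_i)
eta, h = 0.01, [1., 1.]   # h = (1, 1) makes q_hat coincide with q
theta = [0.9, 0.9]
for _ in range(20):
    theta = [t + eta * (-2. * hi * t) for t, hi in zip(theta, h)]

assert q(theta) > q([0.9, 0.9])   # performance strictly improves
```

Each step shrinks every coordinate by a factor of $1 - 2\eta h_i$, so with this $h$ the iterates decay geometrically toward the optimum at the origin.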

Here’s the PBT implementation. I’ve added flags to run PBT in explore-only, exploit-only, and grid-search modes.

def pbt(grid=False, explore_only=False, exploit_only=False):
    # a check to ensure only one mode is selected for PBT
    assert grid + exploit_only + explore_only <= 1, "at most one flag can be set for PBT modes"
    # init population
    population = [
        Member(theta=[0.9, 0.9], h=[0,1], _id=0),
        Member(theta=[0.9, 0.9], h=[1,0], _id=1)
        ]
    member_ids = np.arange(len(population))
    # begin training
    for _ in range(75):
        np.random.shuffle(member_ids)
        for mem_id in member_ids:
            member = population[mem_id]
            member.step()
            member.eval()
            if member.ready() and not grid:
                if explore_only:
                    member.explore()
                else:
                    # exploit the other member of the two-member population
                    member.exploit(population[(mem_id + 1) % len(population)])
                    if not exploit_only:
                        member.explore()
                    member.eval()
            member.log()

    traces = [population[i].trace for i in range(len(population))]
    ps = [population[i].ps for i in range(len(population))]
    return traces, ps

Let’s plot the four modes:
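The helpers `plot_traces` and `plot_curves` are not defined above; here is a minimal hypothetical version consistent with the data `pbt` returns (each trace is a list of $(\theta_0, \theta_1)$ pairs, each element of `ps` a list of $Q$ values over steps):

```python
import matplotlib.pyplot as plt

def plot_traces(ax, traces, title=''):
    # parameter trajectories in (theta_0, theta_1) space, one series per member
    for trace in traces:
        ax.plot([t[0] for t in trace], [t[1] for t in trace], '.', markersize=2)
    ax.set_title(title)
    ax.set_xlabel('theta_0')
    ax.set_ylabel('theta_1')

def plot_curves(ax, ps, title=''):
    # Q over training steps, one curve per member
    for p in ps:
        ax.plot(p)
    ax.set_title(title)
    ax.set_xlabel('step')
    ax.set_ylabel('Q')
```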

f, axes = plt.subplots(2, 4)

traces, ps = pbt(grid=True)
plot_traces(axes[0,0], traces, title='Grid Search')
plot_curves(axes[0,2], ps, title='Grid Search')

traces, ps = pbt(exploit_only=True)
plot_traces(axes[0,1], traces, title='Exploit Only')
plot_curves(axes[0,3], ps, title='Exploit Only')

traces, ps = pbt(explore_only=True)
plot_traces(axes[1,0], traces, title='Explore Only')
plot_curves(axes[1,2], ps, title='Explore Only')

traces, ps = pbt()
plot_traces(axes[1,1], traces, title='PBT')
plot_curves(axes[1,3], ps, title='PBT')

plt.tight_layout()
plt.show()

which results in the following: