The Principled Agent

The journey to a better policy.


Can You See What I See – Breakout Baseline #2

The biggest problem with tuning my RL agents is that I’m flying blind. How can I possibly turn the right knobs when my only instruments are a scrolling loss value and a final score?

Today, we’re upgrading the cockpit with a full instrument panel. Let’s visualize what our agent is actually thinking and use that information to tune it.

Where We Left Off

If you missed the last post, here is a quick rundown:

I finished coding a JAX-based PPO Actor-Critic agent to play the MinAtar Breakout environment, and learned JAX along the way! After a deep dive into log_softmax and why it is useful, I found the agent was performing quite poorly: an average episodic reward of 3.98, compared to the baseline’s 28. Just dreadful.

Usually, I would spend a few hours guessing what was wrong, tweaking hyperparameters, potentially rewriting major portions of code, and hoping I would stumble onto a fix.

But the whole point of this blog is to find a better way. So today, we’re doing it differently: we’re going to visualize everything we can think of to form a hypothesis, visualize more to increase our certainty, and only then make changes to our agent!

Watcha Doin’?

Okay, the first, and perhaps most obvious, thing we can visualize is an actual rollout of the environment. After a quick Google search, OpenCV seems to be a good way to generate videos in Python. I’ve never used it, so… more learning!

I remember watching the AlphaStar exhibition back in 2019 and, if memory serves, they showed a video of the actual environment as the agent played, so viewers could see what the agent was seeing. And on the right of that playback, they had some information panels for further insight into the agent. Let’s recreate that. As far as information goes, the action probabilities and the value estimate seem like good starting points. Together, they cover both of our networks (actor and critic) at a high level.
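Here is a minimal sketch of that layout, assuming MinAtar Breakout observations are 10×10 binary grids whose channels are ordered paddle, ball, trail, brick (check this against your environment). Instead of Matplotlib, it draws the side panel directly with NumPy bars (one per action probability, plus a final one for the value estimate) and hands the stitched frames to OpenCV's `VideoWriter`. The function names, colors, and value display range are hypothetical choices for illustration, not the code behind this post.

```python
import numpy as np

# Assumption: MinAtar Breakout channel order is paddle, ball, trail, brick.
# Colors are RGB; they are converted to BGR only when handed to OpenCV.
CHANNEL_COLORS = np.array(
    [
        [255, 255, 255],  # paddle: white
        [255, 215, 0],    # ball: gold
        [110, 110, 110],  # trail: grey
        [200, 60, 60],    # brick: red
    ],
    dtype=np.uint8,
)


def render_obs(obs, scale=32):
    """Turn a (10, 10, C) binary observation into an upscaled RGB image."""
    obs = np.asarray(obs)
    img = np.zeros(obs.shape[:2] + (3,), dtype=np.uint8)
    for ch in range(obs.shape[-1]):  # later channels paint over earlier ones
        img[obs[..., ch] > 0] = CHANNEL_COLORS[ch]
    return img.repeat(scale, axis=0).repeat(scale, axis=1)


def render_panel(action_probs, value, height, width=160, value_range=(0.0, 30.0)):
    """Draw one bar per action probability, plus a final bar for the value estimate."""
    panel = np.full((height, width, 3), 30, dtype=np.uint8)
    lo, hi = value_range  # hypothetical display range for the critic's output
    fractions = [float(p) for p in action_probs]
    fractions.append(min(max((float(value) - lo) / (hi - lo), 0.0), 1.0))
    bar_w = width // len(fractions)
    for i, frac in enumerate(fractions):
        bar_h = int(round(frac * height))
        color = (80, 160, 255) if i < len(fractions) - 1 else (80, 220, 120)
        panel[height - bar_h:, i * bar_w + 2:(i + 1) * bar_w - 2] = color
    return panel


def compose_frame(obs, action_probs, value, scale=32, panel_width=160):
    """Game view on the left, info panel on the right, as one RGB frame."""
    game = render_obs(obs, scale)
    panel = render_panel(action_probs, value, height=game.shape[0], width=panel_width)
    return np.concatenate([game, panel], axis=1)


def write_video(frames, path="rollout.mp4", fps=10):
    """Write RGB frames to a video file with OpenCV (which expects BGR)."""
    import cv2  # opencv-python; imported here so the drawing helpers work without it

    height, width, _ = frames[0].shape
    writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    for frame in frames:
        writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    writer.release()
```

During a rollout you would append one `compose_frame(obs, probs, value)` per environment step and call `write_video(frames)` at the end. Drawing bars directly keeps every frame the same size, which `VideoWriter` requires.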

So after some fiddling with OpenCV and Matplotlib, I’ve got something up and running. Wanna see the agent in a full rollout after training? Me too.

🔍Analysis

Some of the more experienced among you might already recognize the problem. I certainly didn’t see the signs until I rewatched it while uploading it here.

The agent is moving around, and it can kind of track the ball. The action probabilities are jumping around, so it isn’t stuck on a single action.

The value estimate, on the other hand, is stuck at a single value. Let’s keep an eye on that.

Hmm, things clearly don’t look good. But I don’t think we have enough information yet to form a hypothesis.

Watcha Lookin’ At?

Okay, just watching the agent perform actions isn’t helping me much. Maybe it would help to see what the agent is focused on? The agent could easily have caught that final ball it missed if they knew it was there, so why didn’t they move the paddle? Are they just looking in the wrong place? (I’m going to be anthropomorphizing agents a lot here, so be prepared.)

Okay, another Google search and I’ve found a good candidate for visualizing what the agent is focusing on. I went down a rabbit hole of something called Grad-CAM before realizing it is specific to Convolutional Neural Networks (CNNs) and we’re using MLPs. If using MLPs instead of CNNs for pixel input drives you crazy, be patient. We’ll get there.

Eventually, I landed on a vanilla gradient salience map. We’re going to calculate the gradient of the policy network’s output for the chosen action with respect to the input image. In essence, we’re finding which input pixels would, if nudged slightly, have the biggest impact on the action the agent picks.
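A minimal JAX sketch of that idea, assuming the actor can be wrapped as a function from a single (10, 10, C) observation to a vector of logits. It differentiates the chosen action's logit (differentiating its log-probability is an equally common choice) with respect to every input pixel, takes the absolute value, and keeps the strongest channel per pixel. `saliency_map` and `policy_fn` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def saliency_map(policy_fn, obs, action=None):
    """Vanilla gradient saliency for one observation.

    policy_fn: maps an observation of shape (10, 10, C) to logits of shape (num_actions,).
    action: which action's logit to explain; defaults to the greedy action.
    Returns a (10, 10) map of |d logit / d pixel|, max-pooled over channels, scaled to [0, 1].
    """
    obs = jnp.asarray(obs, dtype=jnp.float32)
    if action is None:
        action = jnp.argmax(policy_fn(obs))

    def chosen_logit(x):
        return policy_fn(x)[action]

    grads = jax.grad(chosen_logit)(obs)     # same shape as obs
    sal = jnp.max(jnp.abs(grads), axis=-1)  # strongest channel per pixel
    return sal / (jnp.max(sal) + 1e-8)      # normalize for display
```

With a Flax actor like the one in this post, `policy_fn` could be something like `lambda x: actor.apply({'params': actor_params}, x[None])[0]`, which adds and then strips a batch dimension; that wrapper is an assumption about the actor's expected input shape.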

Let’s superimpose this salience map on our video of the agent playing and see what the agent is looking at! The darker blue a pixel is, the more the agent is focusing on that spot.
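One way to do the superimposing, sketched with NumPy under the assumption that the frame is an RGB uint8 image whose height and width are integer multiples of the 10×10 saliency map: upscale the map and alpha-blend each pixel toward blue in proportion to its saliency. `overlay_saliency`, the default color, and the alpha value are illustrative choices.

```python
import numpy as np


def overlay_saliency(frame, saliency, color=(0, 0, 180), alpha=0.8):
    """Blend a saliency map in [0, 1] over an RGB uint8 frame.

    frame: (H, W, 3) uint8 image; H and W must be multiples of the map's size.
    saliency: (h, w) values in [0, 1]; higher means more influence on the action.
    """
    saliency = np.asarray(saliency, dtype=np.float32)
    sy = frame.shape[0] // saliency.shape[0]
    sx = frame.shape[1] // saliency.shape[1]
    upscaled = saliency.repeat(sy, axis=0).repeat(sx, axis=1)  # (H, W)
    weight = (alpha * upscaled)[..., None]                      # (H, W, 1)
    blended = (1.0 - weight) * frame.astype(np.float32) + weight * np.asarray(color, dtype=np.float32)
    return np.clip(np.round(blended), 0, 255).astype(np.uint8)
```

Pixels the policy ignores keep their original color, while the most influential pixels are pushed almost entirely to blue.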

🔍Analysis

Seems fine, right? The agent is looking at the path of the ball as it follows the diagonal trajectory. Wait a second… it’s just focused on two lanes, isn’t it? Watch it again. There are two 3-pixel-wide diagonal lanes it cares about. Outside of those, it’s totally blind.

Well, this is just one rollout. Maybe the agent does fine in other rollouts where the ball is outside these two lanes. We should double-check before we jump to any conclusions, right?

Where’d You See It Last?

I know, I’m stretching the theme of these section headers a little bit. So, I’ve noticed something off in the single rollout with the diagonal lanes. I’m curious whether the agent always fails in the same way, or whether this is just a one-off that’s going to send me down a not-so-useful path.

A creator named Peter Whidden recently made a YouTube video of an RL agent that played Pokemon Red. In that video, as part of their debugging process, they created a heatmap to show all the locations where the agent was failing. I think we can steal that idea to get a feel for whether these diagonal lanes are a real problem.

Let’s run our agent 1000 times and, for each pixel, count how many times the ball or paddle occupies that pixel in the frame right before the agent loses. We can then turn those counts into a heatmap showing where the important objects on screen most often are when the agent loses.
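The counting step is only a few lines. This sketch assumes you have already stacked the last observation of each losing episode into one array, and that the paddle and ball live in channels 0 and 1 of the MinAtar Breakout observation (verify this against your environment). A pixel counts once per episode if either object is on it.

```python
import numpy as np

PADDLE, BALL = 0, 1  # assumed MinAtar Breakout channel indices


def failure_heatmap(final_obs, channels=(PADDLE, BALL)):
    """Count, per pixel, how often the paddle or ball occupied it on the frame before a loss.

    final_obs: (num_episodes, 10, 10, C) stack holding the last observation of each losing episode.
    """
    final_obs = np.asarray(final_obs)
    occupied = (final_obs[..., list(channels)] > 0).any(axis=-1)  # (num_episodes, 10, 10)
    return occupied.sum(axis=0)
```

Plotting is then a single Matplotlib call such as `plt.imshow(failure_heatmap(final_frames), cmap="hot")`, where `final_frames` is whatever array your rollout loop collected.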

Okay, that was pretty simple to throw together with Matplotlib. Let’s check it out:

🔍Analysis

Well, that certainly seems like a clear sign the agent likes those lanes a lot. But why? They should be able to learn to handle such a small deviation from the lanes where they succeed.

The agent doesn’t seem to be able to handle the changing dynamics of the environment as the game progresses.

Okay, so… we have an agent that can handle the initial part of the game pretty well, has learned to watch the ball early on, but as soon as the dynamics change a little, loses immediately. I think I have an idea of what is going wrong here.

The agent seems to be converging to a suboptimal policy very early instead of continuing to explore as the environment changes. So, we should increase the entropy coefficient in our PPO loss to encourage more exploration! We get to use our first Hypothesis block:

🤔Hypothesis

If we increase the entropy coefficient in our PPO loss, the agent will get higher average episodic rewards after training, because it will not prematurely converge to a suboptimal policy and will explore the environment better.

Why You Acting Weird?

Awesome, we have a direction and a hypothesis. So, what is this entropy coefficient in our PPO loss, you may be asking?

In RL, a naively coded agent will tend to find a good-enough policy and stick to it rather than learning new, perhaps better policies. We don’t want our agent to get stuck in these sub-optimal policies; we want to force them to explore a little bit.

Enter: entropy bonuses. When calculating the loss of our agent, we can give the agent a little reward if they’re not very certain about their action choice (high entropy). Then, because we sample the agent’s action from its probability distribution rather than always taking the most likely one, we’ll occasionally pick an action other than their favorite. This, in essence, is a type of exploration.
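To make "uncertain" concrete: the entropy of a categorical policy is H = -Σ p(a) log p(a), which is largest when every action is equally likely and near zero when one action dominates. The sketch below computes it from logits via log_softmax and also shows the weighted sampling with `jax.random.categorical`. The helper names are hypothetical.

```python
import jax
import jax.numpy as jnp


def categorical_entropy(logits):
    """Entropy -sum_a p(a) log p(a) of each action distribution, computed stably from logits."""
    log_p = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.sum(jnp.exp(log_p) * log_p, axis=-1)


def sample_actions(key, logits, n):
    """Draw n actions from the policy's distribution (a weighted sample, not an argmax)."""
    return jax.random.categorical(key, logits, shape=(n,))
```

For three equally likely actions the entropy is log 3 ≈ 1.10. With logits [2, 0, 0], the favorite action has probability of about 0.79, so roughly one sample in five is still some other action.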

The strength of this entropy bonus is controlled by a coefficient. Too high, meaning you reward the agent a lot for being uncertain, and your agent never learns anything because they get enough reward just by acting randomly. Too low, like what I suspect we’re seeing now, and your agent doesn’t explore at all and gets stuck in a sub-optimal policy.
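Written out, the loss that ppo_loss is meant to minimize looks like this, with r_t the probability ratio computed in policy_loss, ε the clip range (self._clip_epsilon), c_H the entropy coefficient (self._entropy_coefficient), Â_t the advantage, and N the batch size. Because the entropy term is subtracted, a larger c_H makes uncertainty more attractive to the optimizer.

```latex
\begin{aligned}
\mathcal{L}(\theta) &= -\frac{1}{N}\sum_{t=1}^{N} \min\Big( r_t(\theta)\,\hat{A}_t,\ \operatorname{clip}\big(r_t(\theta), 1-\epsilon, 1+\epsilon\big)\,\hat{A}_t \Big) - c_H \frac{1}{N}\sum_{t=1}^{N} \mathcal{H}\big(\pi_\theta(\cdot \mid s_t)\big) \\
r_t(\theta) &= \exp\big(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_{\text{old}}}(a_t \mid s_t)\big) \\
\mathcal{H}\big(\pi_\theta(\cdot \mid s)\big) &= -\sum_{a} \pi_\theta(a \mid s) \log \pi_\theta(a \mid s)
\end{aligned}
```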

So, let’s see if my hypothesis is right by tweaking the entropy coefficient. If I change it and our agent’s performance doesn’t change, we’ll know my hypothesis was wrong. Right now, I have it set at 0.1.

Let’s try 0.5 and run it. Okay, done! Hmm… the agent seems to be doing the exact same thing. I must have had it set way too low!

How about 1? Still nothing. Hmm…

How about 10? 100?

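Okay, the agent looks exactly the same for all of these coefficients. At 100, the entropy bonus should push the agent toward acting almost randomly, so a tuning problem can’t explain this. Something is wrong with my loss!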

What’s Buggin’ You?

Okay, let’s look at my loss definition:

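```python
import jax
import jax.numpy as jnp


class PPO:  # class name assumed for illustration; the post shows only these methods
    def __init__(self, clip_epsilon: float, entropy_coefficient: float):
        self._clip_epsilon = clip_epsilon
        self._entropy_coefficient = entropy_coefficient

    def policy_loss(self, logits: jnp.ndarray, old_log_probs: jnp.ndarray, advantages: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
        # Log-probability of the action actually taken, under the current policy
        one_hot_actions = jax.nn.one_hot(actions, num_classes=logits.shape[-1])

        log_probs = jax.nn.log_softmax(logits)
        log_probs = jnp.sum(log_probs * one_hot_actions, axis=-1)

        # Probability ratio pi_new(a|s) / pi_old(a|s)
        ratio = jnp.exp(log_probs - old_log_probs)

        # Clipped surrogate objective, negated so it can be minimized
        loss = jnp.minimum(ratio * advantages, jnp.clip(ratio, 1 - self._clip_epsilon, 1 + self._clip_epsilon) * advantages)
        loss = -jnp.mean(loss)

        return loss

    def entropy_bonus(self, logits: jnp.ndarray) -> jnp.ndarray:
        # Mean entropy of the per-state action distributions
        probs = jax.nn.softmax(logits)
        entropy_per_distribution = -jnp.sum(probs * jnp.log(probs + 1e-8), axis=-1)
        return jnp.mean(entropy_per_distribution)

    def ppo_loss(self, logits: jnp.ndarray, old_log_probs: jnp.ndarray, advantages: jnp.ndarray, actions: jnp.ndarray) -> jnp.ndarray:
        # Subtracting the weighted entropy rewards more uncertain (exploratory) policies
        return self.policy_loss(logits, old_log_probs, advantages, actions) - self._entropy_coefficient * self.entropy_bonus(logits)
```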

And where I’m using the loss in my code:

```python
def actor_loss_fn(actor_params):
    new_logits = actor.apply({'params': actor_params}, trajectory.obs.reshape(-1, 10, 10, 4))
    return ppo_impl.policy_loss(new_logits, trajectory.log_prob.flatten(), advantages.flatten(), trajectory.action.flatten())
```

See the issue? It took me a second.

Yup, I’m using policy_loss, which does not include the entropy bonus at all! I should be calling ppo_loss instead. That explains both why my agent isn’t exploring and why changing the entropy coefficient had no effect.
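In hindsight, a tiny check would have caught this immediately: if the entropy coefficient is really wired into the loss, changing it must change the gradient. The sketch below re-creates the two losses in compact form (the 0.2 clip value and all function names are assumptions for illustration) and compares gradients with respect to the logits at a small and a large coefficient.

```python
import jax
import jax.numpy as jnp

CLIP_EPS = 0.2  # assumed value; the post doesn't state its clip epsilon


def clipped_surrogate(logits, old_log_probs, advantages, actions):
    """Negated PPO clipped objective, equivalent to policy_loss."""
    log_probs = jnp.take_along_axis(jax.nn.log_softmax(logits), actions[:, None], axis=-1)[:, 0]
    ratio = jnp.exp(log_probs - old_log_probs)
    clipped = jnp.clip(ratio, 1 - CLIP_EPS, 1 + CLIP_EPS)
    return -jnp.mean(jnp.minimum(ratio * advantages, clipped * advantages))


def mean_entropy(logits):
    log_p = jax.nn.log_softmax(logits)
    return jnp.mean(-jnp.sum(jnp.exp(log_p) * log_p, axis=-1))


def buggy_loss(logits, old_log_probs, advantages, actions, entropy_coef):
    # Accepts the coefficient but never uses it, like calling policy_loss
    return clipped_surrogate(logits, old_log_probs, advantages, actions)


def fixed_loss(logits, old_log_probs, advantages, actions, entropy_coef):
    # Surrogate minus the weighted entropy bonus, like calling ppo_loss
    return clipped_surrogate(logits, old_log_probs, advantages, actions) - entropy_coef * mean_entropy(logits)


def coef_changes_gradient(loss_fn, logits, old_log_probs, advantages, actions):
    """True if d(loss)/d(logits) differs between a small and a large entropy coefficient."""
    grad_fn = jax.grad(loss_fn)  # gradient w.r.t. the first argument, the logits
    g_small = grad_fn(logits, old_log_probs, advantages, actions, 0.1)
    g_large = grad_fn(logits, old_log_probs, advantages, actions, 10.0)
    return not bool(jnp.allclose(g_small, g_large))
```

Applied to real training code, the same idea is to compute the actor gradient at two coefficient settings and assert that the results differ.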

Okay, fixed. I’m calling the right loss function now. Let’s set the entropy coefficient back to our original value, 0.1, and see how we’re doing.

Results

First, let’s just run it and see what our average episodic reward is. Across 1024 episodes, the average episodic reward is 9.81.

Hey! That’s way better! We’ve more than doubled our average episodic reward (from 3.98 to 9.81) by fixing that bug and getting the entropy bonus into our loss calculation! We’re still well short of the 28 that the GitHub baseline boasts, but we’ve made a major step.

Let’s watch the agent and see what it looks like:

🔍Analysis

The agent’s attention is now quite chaotic. This is a far cry from the rigid tunnel vision we saw before. Our new agent is constantly challenging its own policy. You can see in the action probabilities that it’s much less certain, which is exactly the exploratory behavior we wanted to encourage. These are all signs of a much healthier learning process.

A score of 9.81 is still a long way from the baseline of 28, but I’m thrilled. Why? Because this isn’t just a higher score. It’s a proof of concept. It shows that by moving from blind guesswork to methodical visualization, we can hunt down subtle bugs and make real, measurable progress. We didn’t just fix a line of code. We validated a better way of working.

Of course, fixing this bug has made it clear just how critical that one little entropy term is. It’s the first knob we’ve turned with purpose, and it immediately raises new questions. How do you find the correct value for the entropy coefficient? What does the trade-off between exploration and exploitation actually look like?

That’s exactly what we’ll tackle next time, as we do a proper deep dive into the art of tuning entropy.

Code

You can find the complete, finalized code for this post in the V1.1 release on GitHub.


