Table of contents

DeepSeek R1: GRPO in action – A battlefield analogy for next-gen LLMs
LLM pre- and post-training
The Battlefield: Battleship as a simulation environment
Setting Up the Game Environment
The Environment and the Policy
Introducing GRPO: the group tactics
Training the policy: the final push
Parallels to DeepSeek R1’s efficiency
Conclusion

DeepSeek R1: GRPO in action – A battlefield analogy for next-gen LLMs

28 Mar 2025

What if training powerful AI models could be faster, cheaper, and more efficient? DeepSeek R1’s GRPO is changing the game, cutting memory and compute costs nearly in half. Through a Battleship-inspired simulation, learn how this breakthrough is reshaping Reinforcement Learning.

DeepSeek R1 has set the AI community abuzz, and today we’re diving deep into one of its secret weapons: Group Relative Policy Optimization (GRPO), introduced in the DeepSeekMath paper. Unlike traditional PPO (Proximal Policy Optimization), GRPO brings a fresh, battle-hardened approach to Reinforcement Learning, slashing memory and compute overhead by nearly 50% and significantly reducing training cost. In this post, we’ll unpack a Python code example that simulates GRPO on a virtual Battleship battlefield and draw parallels to how DeepSeek R1 uses the same principles to revolutionize LLM training. In under three hundred lines of pure NumPy, we demonstrate how GRPO works and what value it adds.

LLM pre- and post-training

To understand why DeepSeek R1 is so cost-effective to train, let us recall how an LLM is trained from a high-level perspective. The training process of a Large Language Model (LLM) consists of two main phases: pre-training and post-training.

1. Pre-training: this is the initial phase where the model is exposed to vast amounts of high-quality data to learn next-token prediction. Essentially, it’s about feeding the model large-scale data and allowing it to develop a foundational understanding of language. Given the significant computational cost, this phase is typically carried out by large companies, while smaller companies often focus more on post-training.

2. Post-training: this phase focuses on refining the model’s reasoning abilities and is typically divided into two key steps:

  • Stage 1 – supervised fine-tuning: here, the model is fine-tuned using a smaller, high-quality dataset curated by experts. The goal is to teach the model to follow instructions, answer questions, and perform structured reasoning (e.g., chain-of-thought reasoning). Ideally, if unlimited expert data were available, this would be the best way to enhance the model. However, since high-quality data is scarce, an additional step is needed.

  • Stage 2 – Reinforcement Learning from Human Feedback (RLHF): Since expert reasoning data is limited, reinforcement learning (RL) is employed to bridge the gap. RLHF involves training a reward model based on human feedback, which then helps guide the LLM’s learning through reinforcement learning. This process ensures the model better aligns with human preferences, making it more effective and reliable in real-world applications.

RLHF is an expensive step in LLM training due to the high cost of human annotation, where expert annotators rank or critique model outputs. It also requires intensive compute resources for training a reward model and running Proximal Policy Optimization (PPO), which involves multiple inference steps and policy updates.

DeepSeek’s revolutionary idea is to skip the supervised fine-tuning step and apply reinforcement learning directly to their DeepSeek V3 foundation model, thereby significantly reducing cost and eliminating the human bias present in curated datasets. Furthermore, they introduced the GRPO algorithm as a replacement for the PPO algorithm. To keep the discussion away from deep technical details, we will apply GRPO to the game of Battleship to explain how it works.
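
To see where the savings come from, here is a rough, illustrative sketch (not DeepSeek’s actual training code) contrasting how PPO and GRPO estimate the advantage of an action: PPO keeps and queries a separate critic network, while GRPO only needs a group of sampled outcomes for the same starting point.

import numpy as np

# PPO (sketch): a learned value/critic network provides the baseline for every step.
def ppo_advantage(reward: float, value_estimate: float) -> float:
    return reward - value_estimate  # simplified; real PPO uses GAE over the trajectory

# GRPO (sketch): the group of sampled outcomes is its own baseline, so no critic is needed.
def grpo_advantages(group_rewards: list[float]) -> np.ndarray:
    rewards = np.asarray(group_rewards, dtype=np.float32)
    return (rewards - rewards.mean()) / max(rewards.std(), np.finfo(np.float32).eps)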

The Battlefield: Battleship as a simulation environment

Imagine you’re commanding a fleet in a high-stakes Battleship game. Every move - each missile fired - is critical. In our simulation, the BattleshipGame class sets the stage with:

  • A 5×5 board where ships (“🚢”) are hidden beneath waves (“🌊”).

  • A dynamic environment where each missile can either miss (“💦”) or hit (“💥”), affecting your overall score.

This battlefield isn’t just for fun - it’s a metaphor for how GRPO treats each decision (or token generation) as a tactical move in an evolving campaign.

@dataclass
class BattleshipGameRules:
    board_size: int = 5  # A 5 x 5 game board
    ships: tuple[int, ...] = (0, 0, 1, 1, 1, 0)  # Index = ship size, value = count: 1x destroyer (2), 1x cruiser/submarine (3), 1x battleship (4)


@dataclass
class BattleshipGame:
    board: BattleshipGameBoard
    rules: BattleshipGameRules = field(default_factory=BattleshipGameRules)

    @staticmethod
    def random_board(rules: BattleshipGameRules | None = None, seed: int | None = None) -> BattleshipGameBoard:
        rules = rules or BattleshipGameRules()
        random_state = np.random.RandomState(seed)
        ships = [ship_size for ship_size, ship_count in enumerate(rules.ships) for _ in range(ship_count)]
        ships_placed = False
        while not ships_placed:
            board = np.full((rules.board_size, rules.board_size), "🌊", dtype=np.str_)
            for ship_size in ships:
                if ship_size > rules.board_size:
                    raise ValueError(f"Ship of size {ship_size} does not fit on a {rules.board_size}x{rules.board_size} board")
                ship_top_left = random_state.randint(low=0, high=rules.board_size - (ship_size - 1), size=2)
                ship_bottom_right = ship_top_left + 1
                ship_bottom_right[random_state.randint(low=0, high=2)] += ship_size - 1
                if np.all(board[ship_top_left[0] : ship_bottom_right[0], ship_top_left[1] : ship_bottom_right[1]] == "🌊"):
                    board[ship_top_left[0] : ship_bottom_right[0], ship_top_left[1] : ship_bottom_right[1]] = "🚢"
                else:
                    break
            else:
                ships_placed = True
        return board

    def play(self, fire: tuple[int, int]) -> bool:
        hit = self.board[fire] in ("🚢", "💥")
        self.board[fire] = "💥" if hit else "💦"
        return hit

    def score(self) -> float:
        done = not np.any(self.board == "🚢")
        efficiency: float = 1.0 - np.sum(self.board == "💦") / (self.board.size - np.sum(self.board == "💥") + 1) if done else 0.0
        return efficiency

    def __repr__(self) -> str:
        return "\n".join("".join(row) for row in self.board)

Setting Up the Game Environment

The simulation begins with a random board generator that places ships on the grid. The game rules and board setup mirror the uncertain, dynamic nature of real-world decision-making, where not every move leads to a hit.

  • Random Board Generation: 

    The random_board method continuously attempts to place all ships without overlapping until a valid board configuration is achieved. This randomness is akin to the initial uncertainty in model responses before optimization.

  • Game Mechanics:

    The play and score methods simulate taking an action (firing a missile) and evaluating the outcome, respectively. The score stays at zero until every ship has been sunk; once the fleet is destroyed, it rewards efficiency, so fewer wasted missiles mean a higher final score. A minimal usage sketch follows below.
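
The sketch assumes the classes above and numpy are in scope; the target coordinates are arbitrary.

game = BattleshipGame(board=BattleshipGame.random_board(seed=42))
print(game)                   # the emoji board, ships included, since we own this board
hit = game.play(fire=(2, 3))  # fire a missile at row 2, column 3
print("💥" if hit else "💦")
print(game.score())           # stays 0.0 until the last 🚢 tile has been destroyed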

The Environment and the Policy

Our BattleshipEnv class wraps the game logic into a reinforcement learning environment:

  • Observation Encoding:
    The board is encoded into a numerical format (e.g., fog of war, hit, or miss), which serves as the “state” the policy observes.

class BattleshipEnv(Environment):
    rules = BattleshipGameRules()
    max_steps = rules.board_size**2

    def __init__(self, init_seed: int | None = None, step_seed: int | None = None) -> None:
        self.state = BattleshipGame(board=BattleshipGame.random_board(self.rules, init_seed), rules=self.rules)
        self.random_state = np.random.RandomState(step_seed)

    @property
    def observation(self) -> ObservationArray:
        # 0 = fog of war, -1 = missile miss, 1 = missile hit
        encoded_board = np.zeros(self.state.board.shape, dtype=np.float32)
        encoded_board[self.state.board == "💦"] = -1.0
        encoded_board[self.state.board == "💥"] = 1.0
        return encoded_board

    @classmethod
    def reset(cls, init_seed: int | None = None, step_seed: int | None = None) -> tuple["BattleshipEnv", ObservationArray]:
        env = cls(init_seed, step_seed)
        return env, env.observation

    def sample_action(self, action_proba: ActionProbaArray) -> int:
        # Mask out illegal actions.
        illegal_actions = np.ravel(self.observation != 0.0)
        action_proba[illegal_actions] = 0.0
        action_proba /= np.sum(action_proba)
        # Sample an action from the probability distribution.
        action = int(self.random_state.choice(len(action_proba), p=action_proba))
        return action

    def step(self, action: int) -> tuple[ObservationArray, float, bool]:
        self.state.play(fire=divmod(action, self.state.rules.board_size))
        reward = self.state.score()
        done = reward > 0.0
        return self.observation, reward, done

  • Policy Model:
    A simple neural network (neural_battleship_policy) maps these observations to action probabilities - essentially deciding where to fire next.

def neural_battleship_policy_init(rules: BattleshipGameRules | None = None, seed: int = 42) -> ParamsDict:
    rules = rules or BattleshipGameRules()
    num_tiles = rules.board_size**2
    random_state = np.random.RandomState(seed)
    scale = np.sqrt(2.0 / (2 * num_tiles))  # Xavier/Glorot initialization
    params = {
        "W1": random_state.normal(scale=scale, size=(num_tiles, num_tiles)).astype(np.float32),
        "b1": np.zeros(num_tiles, dtype=np.float32),
        "W2": random_state.normal(scale=scale, size=(num_tiles, num_tiles)).astype(np.float32),
        "b2": np.zeros(num_tiles, dtype=np.float32),
    }
    return params


def neural_battleship_policy(params: ParamsDict, observation: ObservationArray) -> ActionProbaArray:
    # A simple feedforward neural network with a single hidden layer.
    x = np.ravel(observation)
    h = np.tanh(params["W1"] @ x + params["b1"])
    logits = params["W2"] @ h + params["b2"]
    logits -= np.max(logits)  # Softmax is invariant to shifting the logits.
    exp_logits = np.exp(logits)
    softmax = exp_logits / np.sum(exp_logits)
    return softmax


def reference_battleship_policy(observation: ObservationArray) -> ActionProbaArray:
    # Fire on any fog of war tile with uniform probability.
    p = np.ravel((observation == 0.0).astype(np.float32)) + np.sqrt(np.finfo(np.float32).eps)
    p = p / np.sum(p)
    return p

Notice how our simulation’s policy mirrors an LLM’s decision-making process. Just as the network picks the next token based on context, our policy selects the next move based on the board’s state.
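
Concretely, a single decision step of the (still untrained) policy looks like the sketch below, assuming the functions defined above are in scope:

params = neural_battleship_policy_init()
env, observation = BattleshipEnv.reset(init_seed=0, step_seed=0)
action_proba = neural_battleship_policy(params, observation)  # 25 probabilities, one per tile
action = env.sample_action(action_proba)                      # already-fired tiles are masked out
observation, reward, done = env.step(action)                  # reward stays 0.0 until the game is won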

Introducing GRPO: the group tactics

Here’s where GRPO takes center stage. Traditional PPO relies on a critic network to evaluate each step, but GRPO does things differently by harnessing group sampling.

@dataclass
class GRPOConfig:
    environment: type[Environment]
    policy: PolicyFunction
    reference_policy: ReferencePolicyFunction

    ε: float = 0.9  # Policy ratio clip epsilon
    ß: float = 0.0  # Weight for KL divergence between the policy and the reference policy
    G: int = 16  # Number of trajectories per group
    B: int = 4  # Number of groups per mini-batch
    M: int = 2048  # Number of mini-batches to train on
    μ: int = 10  # Number of gradient steps per mini-batch


def collect_group(policy_params: ParamsDict, grpo_config: GRPOConfig, env_seed: int | None = None) -> Group:
    # Initialize the group output.
    group_observations: list[list[ObservationArray]] = [[] for _ in range(grpo_config.G)]
    group_actions = [np.empty(grpo_config.environment.max_steps, dtype=np.intp) for _ in range(grpo_config.G)]
    group_actions_proba = [np.empty(grpo_config.environment.max_steps, dtype=np.float32) for _ in range(grpo_config.G)]
    group_rewards = np.zeros(grpo_config.G, dtype=np.float32)
    # Create a fixed environment initialization seed.
    init_seed = env_seed if env_seed is not None else np.random.randint(2**32)
    # Generate trajectories starting from the initial environment.
    for group in range(grpo_config.G):
        # Start a new environment (a game) from a fixed initial seed.
        env, observation = grpo_config.environment.reset(init_seed=init_seed, step_seed=init_seed * group)
        for step in range(env.max_steps):
            # Evaluate the policy model to obtain the action probability distribution.
            action_proba = grpo_config.policy(policy_params, observation)
            # Sample an action from the policy's action probability distribution.
            action = env.sample_action(action_proba)
            # Update the group output.
            group_observations[group].append(observation)
            group_actions[group][step] = action
            group_actions_proba[group][step] = action_proba[action]
            # Advance the environment with the sampled action.
            observation, reward, done = env.step(action)
            # Check if this trajectory is done.
            if done:
                group_rewards[group] = reward  # GRPO only considers the terminal reward.
                break
    # Compute the GRPO advantages across the group, but assign them to the actions within each trajectory.
    group_advantages = (group_rewards - np.mean(group_rewards)) / max(np.std(group_rewards), np.finfo(np.float32).eps)
    return (group_observations, group_actions_proba, group_actions, group_rewards, group_advantages)


def grpo_objective(policy_params: ParamsDict, group: Group, grpo_config: GRPOConfig) -> float:
    # For each trajectory in the given group...
    grpo = 0.0
    for observations, actions_proba, actions, _, advantage in zip(*group):
        # ...accumulate the trajectory's step contributions to the GRPO objective.
        for observation, π_θ_t_old, action in zip(observations, actions_proba, actions):
            π_θ_t = grpo_config.policy(policy_params, observation)[action]
            π_ref_t = grpo_config.reference_policy(observation)[action]
            ratio = π_θ_t / π_θ_t_old
            clipped_ratio = np.clip(π_θ_t / π_θ_t_old, 1 - grpo_config.ε, 1 + grpo_config.ε)
            grpo += min(ratio * advantage, clipped_ratio * advantage) / len(actions)  # Advantage
            grpo += -grpo_config.ß * (π_ref_t / π_θ_t - np.log(π_ref_t / π_θ_t) - 1) / len(actions)  # KL divergence
    grpo /= grpo_config.G
    grpo = -grpo  # Flip the sign to turn the maximization problem into a minimization problem.
    return grpo

  • Collecting groups:
    The collect_group function simulates multiple game trajectories (battles) from a fixed starting point. For each trajectory in a group, we record:

    • Observations: The state at each step.

    • Actions and their probabilities: What move was chosen and how confident the model was.

    • Rewards: Only the terminal reward matters, emphasizing end-game success.


  • Group-normalized advantage:
    Instead of estimating the advantage of each action via a critic, GRPO normalizes the terminal rewards within a group:

Advantage = (Reward – Mean Reward) / max(Standard Deviation, ε)

    This strategy, much like comparing the performance of different battalions in a campaign, allows the model to gauge which trajectories performed best relative to the others.


  • GRPO objective function:
    The grpo_objective function applies a clipped surrogate loss – a safeguard against overzealous updates. It balances:

    • Policy improvement: Ensuring actions that yield high advantages get reinforced.

    • KL penalty: Keeping the new policy close to a trusted reference (our baseline battleship strategy).

    The overall objective is then minimized via gradient descent, guided by the custom AdamW optimizer. In effect, the policy is “trained in the trenches” to make smarter moves in subsequent battles, as the short numeric sketch below illustrates.
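
The sketch uses hypothetical values together with the demo’s clip range of ε = 0.9 from the GRPOConfig above:

import numpy as np

advantage = 1.5                  # this trajectory clearly beat its group's average reward
π_θ_t, π_θ_t_old = 0.30, 0.10    # the updated policy now likes this action three times more
ε = 0.9

ratio = π_θ_t / π_θ_t_old                     # 3.0
clipped_ratio = np.clip(ratio, 1 - ε, 1 + ε)  # 1.9
contribution = min(ratio * advantage, clipped_ratio * advantage)  # 2.85 instead of 4.5: the update is capped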

Training the policy: the final push

The training loop (train_grpo) simulates many mini-batches of battles:

  • Group sampling: For each mini-batch, groups of trajectories are collected.

  • Gradient steps: Multiple updates are performed per mini-batch, refining the policy to maximize terminal rewards.

class AdamWOptimizer:
    def __init__(self, params: ParamsDict, learning_rate: float = 3e-4, ß1: float = 0.9, ß2: float = 0.999, ε: float = 1e-8, λ: float = 0.01) -> None:
        self.params = params
        self.learning_rate = learning_rate
        self.ß1 = ß1
        self.ß2 = ß2
        self.ε = ε
        self.λ = λ
        self.t = 1
        self.state = {key: {"m": np.zeros_like(value), "v": np.zeros_like(value)} for key, value in params.items()}

    def step(self, grad: ParamsDict) -> None:
        for key in self.params:
            self.state[key]["m"] = self.ß1 * self.state[key]["m"] + (1 - self.ß1) * grad[key]
            self.state[key]["v"] = self.ß2 * self.state[key]["v"] + (1 - self.ß2) * (grad[key] ** 2)
            m_hat = self.state[key]["m"] / (1 - self.ß1**self.t)
            v_hat = self.state[key]["v"] / (1 - self.ß2**self.t)
            update = self.learning_rate * (m_hat / (np.sqrt(v_hat) + self.ε) + self.λ * self.params[key])
            self.params[key] -= update
        self.t += 1


def train_grpo(optimizer: AdamWOptimizer, grpo_config: GRPOConfig) -> tuple[ParamsDict, RewardArray]:
    # Define the GRPO objective for a mini-batch of groups of trajectories.
    grpo_objective_batch = lambda policy_params, groups, grpo_config: sum(grpo_objective(policy_params, group, grpo_config) for group in groups)  # noqa: E731
    # Define the gradient of the GRPO objective w.r.t. the policy parameters (the first argument of grpo_objective).
    grpo_objective_batch_grad = grad(grpo_objective_batch)
    rewards_val = np.zeros(grpo_config.M, dtype=np.float32)
    for iter in (pbar := tqdm(range(grpo_config.M))):
        # Collect a mini-batch of groups of trajectories to learn from.
        groups = [collect_group(optimizer.params, grpo_config, env_seed=(iter + 1) * 128 + i) for i in range(grpo_config.B)]
        # Optimize the GRPO objective determined by the current mini-batch for a few steps.
        for _ in range(grpo_config.μ):
            # Compute the gradient and update the solution.
            optimizer.step(grpo_objective_batch_grad(optimizer.params, groups, grpo_config))
        # Track progress of the validation reward.
        groups_val = [collect_group(optimizer.params, replace(grpo_config, G=8), env_seed=i) for i in range(64)]
        rewards_val[iter] = sum(np.mean(group_val[3]) for group_val in groups_val) / len(groups_val)
        pbar.set_description(f"reward_val={rewards_val[iter]:.3f}")
    return optimizer.params, rewards_val
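
Putting the pieces together, a minimal training run could look like the sketch below. Note that the snippets in this post omit their imports: grad is assumed to come from an automatic differentiation package (e.g. autograd), while tqdm and dataclasses.replace are assumed to be imported in the full listing.

grpo_config = GRPOConfig(
    environment=BattleshipEnv,
    policy=neural_battleship_policy,
    reference_policy=reference_battleship_policy,
)
optimizer = AdamWOptimizer(params=neural_battleship_policy_init())
trained_params, rewards_val = train_grpo(optimizer, grpo_config)  # the tqdm bar tracks the validation reward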

Parallels to DeepSeek R1’s efficiency

So how does all this tie back to DeepSeek R1? Here are the key takeaways:

  • Efficiency through simplicity:
    GRPO’s use of group-normalized advantages eliminates the need for a bulky critic network. DeepSeek R1 leverages this by training its LLM directly with RL, bypassing the supervised fine-tuning stage and saving substantial compute time.


  • Robust exploration:
    By sampling multiple trajectories per prompt (or battle), GRPO inherently encourages exploration. This mirrors DeepSeek R1’s strategy of letting the model “self-evolve” its reasoning capabilities through massive parallel sampling.


  • Stability and performance:
    The clipped surrogate loss in GRPO ensures that policy updates are controlled and stable – crucial for maintaining performance in the high-stakes, compute-intensive world of LLM training.

DeepSeek R1’s adoption of GRPO is a testament to how borrowing ideas from other fields (like battlefield tactics) can lead to revolutionary improvements in AI. By drawing on the strategic, group-based sampling approach, DeepSeek R1 not only streamlines RLHF but also paves the way for more efficient and robust LLMs.

Conclusion

The battle-tested GRPO algorithm is more than just a clever trick – it’s a paradigm shift in how we optimize policies for large language models. With DeepSeek R1 harnessing GRPO’s simplicity and efficiency, the future of RL in LLM training looks both promising and thrilling. Whether you’re a seasoned RL expert or just stepping into the arena, the elegance of GRPO’s group strategy is a powerful reminder: sometimes, winning the war is all about how you fight the individual battles.

Dive into the code, explore the tactics, and join the revolution with DeepSeek R1 – where every move counts. Curious about how to easily and efficiently build a RAG pipeline with DeepSeek as a generator? Then have a look at our RAGLite series of blog posts: introduction and guided walkthrough.

Ready to explore more on the cutting edge of LLM training? Stay tuned for our next deep dive into the latest innovations in reinforcement learning and AI development!

Author(s):

Renaud Chrétien

Machine Learning Engineer
