
Technical Implementation

Detailed documentation of the custom RL environments, training infrastructure, and experimental setup.

Training Configuration

Parameter              Value
Base Model             Qwen-30B-A3B
Algorithm              PPO/GRPO
LoRA Rank              32
Batch Size             1,024 (64 prompts × 16 responses)
Prompts per Step       64
Responses per Prompt   16
Total Tokens           188 million
Compute Cost           $44.50
Training Dynamics: Rewards did not increase stably over time, possibly due to model scale or an insufficient number of training steps. We observed a pattern in which the model would flip from being equal to or worse than the base model on a task to being slightly better at intermediate checkpoints.

Training Architecture

Our training harness relies on an Async Off-Policy Reinforcement Learning loop built on top of the Tinker platform. This architecture decouples data generation (rollouts) from model optimization, maximizing GPU utilization.

The "Streaming Minibatch" Loop

The core logic resides in continuous_runner.py, which orchestrates the interaction between two key client types:

  • Sampling Client: A lightweight client responsible for generating text and computing log-probabilities. It pulls weights from the latest checkpoint to ensure data is "fresh" (on-policy or near-on-policy).
  • Training Client: A heavy client that computes gradients and updates the LoRA parameters. It runs continuously in the background.

This setup uses a "Streaming Minibatch" strategy. Instead of waiting for a full dataset to be collected before training (Synchronous PPO), we feed small batches of trajectories into the optimizer as soon as they are ready. This reduces the "time-to-first-update" and creates a smoother learning curve.

# continuous_runner.py (simplified)
async def _timed_async_training(...):
    while i_batch < end_batch:
        # 1. Asynchronously collect trajectories
        wrapped_traj = await trajectory_groups_queue.get()

        # 2. Filter stale data (too far off-policy)
        if not filter_stale(wrapped_traj):
            continue

        # 3. Stream to the optimizer
        new_sampler, metrics = await train.do_train_step_streaming(
            ..., trajectory_groups_queue, training_client
        )

        # 4. Update the sampler to the new weights
        state["sampling_client"] = new_sampler

Off-Policy Staleness Control

We limit off-policy effects by enforcing a maximum staleness threshold on concurrently generated trajectories. Overly old rollouts are discarded and the corresponding tasks are re-issued to maintain policy freshness.
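As a minimal sketch of this check (the field names, threshold value, and re-queue behavior are assumptions, not the actual continuous_runner.py implementation): each trajectory group records the optimizer step it was sampled at, and anything lagging the current policy by more than the threshold is discarded.

from dataclasses import dataclass, field

MAX_STALENESS_STEPS = 2  # assumed threshold: how many optimizer steps a rollout may lag behind

@dataclass
class WrappedTrajectoryGroup:
    sampled_at_step: int          # policy version (optimizer step) that generated the rollout
    trajectories: list = field(default_factory=list)

def is_fresh(group: WrappedTrajectoryGroup, current_step: int,
             max_staleness: int = MAX_STALENESS_STEPS) -> bool:
    """Keep a rollout only if it lags the current policy by at most max_staleness steps."""
    return (current_step - group.sampled_at_step) <= max_staleness

Groups that fail the check are dropped and their tasks re-issued, so the data entering the optimizer stays near-on-policy.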

Verifiers Adapter

To integrate the verifiers library (designed for standard OpenAI APIs) with Tinker's custom RPC protocol, we implemented a VerifiersAdapter. This component acts as a translation layer.

Crucially, it injects a Generation Hook to capture token-level log-probabilities during generation. Standard OpenAI-compatible clients often discard this granular data, but our reward functions rely on it to compute exact probabilities of target words.

# verifiers_adapter.py
async def custom_do_group_rollout(builder, policy):
    # Fix: instantiate deep copies for stateful environments
    envs = [copy.deepcopy(builder.vf_env) for _ in range(group_size)]

    # Run rollouts in parallel
    results = await asyncio.gather(*[run_one_rollout(env) for env in envs])
    return TrajectoryGroup(results, ...)
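The generation hook described above can be sketched roughly as follows; the class, the response attributes, and the wrapping mechanism are illustrative assumptions rather than the actual VerifiersAdapter internals.

# Illustrative sketch of the generation hook (names and response shape are assumed).
class LogprobCapturingHook:
    """Wraps an async sampling call and keeps per-token log-probabilities for reward computation."""

    def __init__(self, sample_fn):
        self._sample_fn = sample_fn   # e.g. the Tinker sampling call used by the adapter
        self.captured = []            # list of (tokens, logprobs), one entry per generation

    async def __call__(self, prompt, **kwargs):
        result = await self._sample_fn(prompt, **kwargs)
        # Assumption: the sampling response exposes the sampled tokens and their logprobs.
        self.captured.append((result.tokens, result.logprobs))
        return result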

Custom Environments: Deep Dive

We developed five distinct RL environments where rewards depend on quantities computed from token-level likelihoods or ground-truth-verifiable self-predictions, rather than human preferences.


1. Latent Code Transmission (Ghost Trace)

Goal

Transmit a target semantic concept (a word drawn from a fixed concept bank) using a fixed-length integer code such that the concept becomes likely under a standardized decoding prompt.

Setup

The model is given a target word and must emit a fixed-length sequence of integers within a specified range (5 integers in [0, 999]), with a strict output format. The harness parses the integers and inserts them into a fixed decoding template ("Sequence: ... Guess the object:").
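For illustration, the parsing and prompt-construction step might look like the sketch below; the regex, constants, and exact template wording are simplifications of the environment's actual format.

import re

CODE_LENGTH = 5    # fixed-length integer code
CODE_MAX = 999     # integers must lie in [0, 999]

def parse_code(completion: str) -> list[int] | None:
    """Extract exactly CODE_LENGTH integers in [0, CODE_MAX]; return None on format violations."""
    nums = [int(n) for n in re.findall(r"\d+", completion)]
    if len(nums) != CODE_LENGTH or any(n > CODE_MAX for n in nums):
        return None
    return nums

def decoding_prompt(code: list[int]) -> str:
    """Insert the emitted code into the fixed listener template."""
    return f"Sequence: {' '.join(map(str, code))}. Guess the object:"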

Reward

The target is the original word. The harness computes reward from the model's log-probability assigned to the target word tokens under the decoding prompt containing the emitted code:

R = (1/k) Σᵢ log P(tokenᵢ | ListenerPrompt) + 10.0

where k is the number of tokens in the target word.

We use the compute_logprobs method on the Tinker Sampling Client to get the exact likelihood of the target tokens given the "number prompt".
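Given the per-token log-probabilities returned for the target word's positions, the reward reduces to the sketch below; the handling of format failures is an assumption.

def ghost_trace_reward(target_token_logprobs: list[float], offset: float = 10.0) -> float:
    """Mean log-probability of the target word's tokens under the listener prompt, plus an offset."""
    if not target_token_logprobs:
        return 0.0  # assumed: parsing/format failures receive no likelihood reward
    k = len(target_token_logprobs)
    return sum(target_token_logprobs) / k + offset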


2. Proper-Scoring-Rule Confidence Reporting

Goal

Answer a labeled synthetic question and report a confidence value c ∈ [0,1] that is incentivized to be calibrated under a strictly proper scoring rule.

Setup

The model answers a question and emits a confidence line (CONFIDENCE: <float>) in a strict format. The harness computes a binary correctness label y ∈ {0,1}.

Reward

Reward combines format validity, task accuracy, and a calibration term computed using the Brier formulation:

R_cal = 1 − (c − y)²

This is a strictly proper scoring rule for probabilistic forecasts—maximized only when the predicted probability matches true empirical frequency.
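A minimal sketch of the combined reward under these definitions; the relative weighting of the format, accuracy, and calibration terms is an assumption.

def confidence_reward(format_ok: bool, is_correct: bool, confidence: float) -> float:
    """Format validity + task accuracy + Brier-style calibration term (illustrative weights)."""
    if not format_ok:
        return 0.0                          # assumed: malformed confidence lines earn nothing
    y = 1.0 if is_correct else 0.0          # binary correctness label
    c = min(max(confidence, 0.0), 1.0)      # clamp the reported confidence to [0, 1]
    r_cal = 1.0 - (c - y) ** 2              # strictly proper (Brier) scoring rule
    return y + r_cal                        # assumed equal weighting of accuracy and calibration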


3. Enumerated-Set Entropy Estimation

Goal

Select a valid discrete response from a constrained set and report an estimate of the flatness of the model's own logprobs that can be validated against the true logprob distribution.

Setup

Each prompt defines a finite set of valid discrete responses (integers satisfying rules like "primes only" or "multiples of 5", or numbers from 1 to 100). The model outputs both a discrete choice and a normalized entropy estimate.

Target Signal

We compute the log-probability of each item in the valid set, normalize these scores into a probability distribution, and use the normalized Shannon entropy of that distribution as the ground truth:

H(P) = −Σᵢ pᵢ log(pᵢ)

Reward

Reward decreases with absolute entropy-estimation error (implemented as a clipped linear score).
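A sketch of the target and score computation, assuming "normalized" means dividing by the maximum entropy log N and using an illustrative clipping scale:

import math

def normalized_entropy(item_logprobs: list[float]) -> float:
    """Normalized Shannon entropy of the model's distribution over the valid answer set.

    item_logprobs holds one log-probability per valid item, e.g. from scoring each
    candidate answer with the sampling client; they are softmax-normalized here.
    """
    if len(item_logprobs) < 2:
        return 0.0
    m = max(item_logprobs)
    weights = [math.exp(lp - m) for lp in item_logprobs]
    total = sum(weights)
    probs = [w / total for w in weights]
    entropy = -sum(p * math.log(p) for p in probs if p > 0.0)
    return entropy / math.log(len(probs))   # divide by max entropy so the target lies in [0, 1]

def entropy_reward(estimate: float, target: float, scale: float = 1.0) -> float:
    """Clipped linear score: 1 at zero error, falling to 0 once the error reaches `scale` (assumed)."""
    return max(0.0, 1.0 - abs(estimate - target) / scale)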


4. Context-Conditioned Likelihood-Shift Prediction

Goal

Predict how prepending auxiliary context changes the model's log-probability of a designated target answer. A separate variant ranks probes by how strongly their predictive distributions change.

Setup

Each instance provides auxiliary context c and one or more probes (qᵢ, aᵢ) with designated target answers. In the in_context variant, the model outputs a single scalar prediction Δ̂ for the log-probability shift.

Target Signal

The harness scores the target answer under two prompts (with and without c) using the scoring API. Ground-truth shift = log P(a|c,q) − log P(a|q).

Reward

Reward is a smooth decreasing function of prediction error (inverse-squared-error style score).
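Concretely, the target and reward reduce to something like the sketch below; the error scale in the score is an assumption.

def likelihood_shift_target(logprob_with_context: float, logprob_without_context: float) -> float:
    """Ground-truth shift: log P(a | c, q) - log P(a | q), both scored via the scoring API."""
    return logprob_with_context - logprob_without_context

def shift_prediction_reward(predicted_shift: float, true_shift: float, scale: float = 1.0) -> float:
    """Inverse-squared-error style score: 1 for a perfect prediction, decaying smoothly with error."""
    error = predicted_shift - true_shift
    return 1.0 / (1.0 + (error / scale) ** 2)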


5. Parameter-Update Sensitivity Prediction

Goal

Predict the effect of a single gradient-based update on the model's future behavior, operationalized as the change in log-probability of a held-out probe answer after one update step on a provided training sample.

Setup

Each instance provides (i) a training datum (x,y) and (ii) an independent probe (q,a). The model outputs a scalar prediction Δ̂upd for how the probe likelihood will change.

The Shadow Client

Measuring the effect of a gradient step is destructive—it changes the model. To do this safely during training, we built a Shadow Client:

# gradient_intuition.py
async def _measure_true_gradient_effect(self, prompt, answer, probe):
    # 1. Reset Shadow Client to match current main model
    await self.shadow_client.reset_weights()

    # 2. Pre-measurement
    prob_pre = await self.shadow_client.compute_logprobs(probe)

    # 3. Perform ONE gradient step on the Shadow Client ONLY
    await self.shadow_client.train_step(prompt, answer)

    # 4. Post-measurement
    prob_post = await self.shadow_client.compute_logprobs(probe)
    return prob_post - prob_pre

The reward is proportional to how accurately the model predicts this prob_post - prob_pre delta.

Note: This environment is substantially more expensive than string-based rewards, requiring additional forward/backward passes and an optimizer step on a model copy. The harness parallelizes reward computation and isolates expensive target computation from the main sampling loop where possible.

Failure Case Analysis

Calibration Tasks: No Improvement

Both confidence-and-accuracy and Brier score calibration tasks showed no significant improvement (p > 0.5), despite using proper scoring rules. The likely cause is a weak gradient signal: the calibration reward R_cal = 1 − (c − y)² has gradient −2(c − y) with respect to the reported confidence, which shrinks toward zero whenever the report already agrees with the correctness label. Since the base model already produces relatively extreme confidence values, and the penalty is small whenever those extremes agree with the binary outcome, the training signal may have been too weak to induce behavioral change.

Likelihood Prediction: High Variance

In-context learning prediction and surprise-ranking tasks showed high step-to-step reward variance (ranging from −0.7 to +3.4 within consecutive batches). The underlying measurement, a difference of log-probabilities, is extremely sensitive to the specific prompt, more so than to the actual signal we tried to learn.

Replication

We implement an orchestration pipeline using GitHub Actions to manage long-running training on preemptible hardware. It checkpoints the model state via the tinker API and commits execution metadata (curriculum indices, metrics, and checkpoint URIs) back to the repository's version history, which allows training to resume across sessions.
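A rough sketch of the resume bookkeeping; the file name, metadata fields, and git invocation are illustrative, and checkpointing itself goes through the tinker API as described above.

import json
import subprocess
from pathlib import Path

STATE_FILE = Path("run_state.json")  # hypothetical metadata file committed to the repository

def load_run_state() -> dict:
    """Resume point: curriculum index, step counter, and the last checkpoint URI."""
    if STATE_FILE.exists():
        return json.loads(STATE_FILE.read_text())
    return {"curriculum_index": 0, "step": 0, "checkpoint_uri": None}

def commit_run_state(state: dict, message: str = "Update training state") -> None:
    """Write the metadata back and commit it so the next workflow run can resume."""
    STATE_FILE.write_text(json.dumps(state, indent=2))
    subprocess.run(["git", "add", str(STATE_FILE)], check=True)
    subprocess.run(["git", "commit", "-m", message], check=True)
    subprocess.run(["git", "push"], check=True)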

To replicate:

  1. Fork the repository at github.com/SauersML/minds_RL
  2. Set the TINKER_API_KEY secret in your repository settings
  3. Run the GitHub Actions workflows