ULT: Unifying Teacher-Student RL with Transformers

The Student-Teacher Paradigm in RL

The student-teacher approach to training RL policies rests on one fact: some privileged information that the simulator provides is not available in the real world. In this setup:

  • The Teacher model learns to roll out better trajectories based on the additional (privileged) info it sees.
  • The Student model learns to imitate the teacher model using only the information available in the real world.

Offline vs. Online Training

This can be done either offline or online:

  • Offline Approach: The trained teacher policy is rolled out, and its trajectories are stored in a dataset of input (data modalities) and output (action) pairs. The inputs are filtered to contain only the data streams the student model will have access to in the real world. The student is then trained on that dataset.
  • Online Approach: The teacher and the student model are trained together, with the student learning to imitate the teacher in real time.
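The offline filtering step can be sketched as below. This is a minimal illustration under stated assumptions. Each observation is treated as a dict of named data streams. The stream names and the `PRIVILEGED_KEYS` set are hypothetical and are not taken from the ULT paper.

```python
# Hypothetical names of data streams that exist only in simulation.
PRIVILEGED_KEYS = {"height_map", "contact_forces"}

def build_student_dataset(teacher_rollouts):
    """Convert recorded teacher (observation, action) pairs into student data.

    Each observation is a dict mapping a stream name to its value. Privileged
    streams are dropped so the student trains only on what a real robot has.
    """
    dataset = []
    for obs, action in teacher_rollouts:
        student_obs = {k: v for k, v in obs.items() if k not in PRIVILEGED_KEYS}
        dataset.append((student_obs, action))
    return dataset

rollouts = [({"joint_pos": [0.1, -0.2], "height_map": [0.0, 0.3]}, [0.5, 0.0])]
print(build_student_dataset(rollouts))
# prints [({'joint_pos': [0.1, -0.2]}, [0.5, 0.0])]
```

The action labels are kept unchanged. Only the inputs lose their privileged streams, so the student learns to produce the teacher's actions from less information.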

The Distribution-Shift Problem

During online training, the student uses its own policy to generate new trajectories. Its small mistakes can compound, leading it into states the teacher rarely or never visited. The student then goes “off-distribution” from the teacher policy, which is known as the distribution-shift problem. To fix this, researchers often use DAgger:

  1. The student policy generates trajectories.
  2. The teacher policy labels the visited data points with its own actions.
  3. The student policy is retrained on this new, aggregated dataset.
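The DAgger loop can be sketched on a toy problem. Everything here is a hypothetical stand-in chosen only to make the loop runnable:

  • a 1-D system with dynamics s' = s + a,
  • a linear "teacher" that pushes the state toward 0,
  • a linear student fitted by least squares.

None of these come from the ULT paper.

```python
import random

def teacher_policy(state):
    # Hypothetical expert for a toy 1-D system: push the state toward 0.
    return -0.5 * state

def rollout(policy, start, steps=10):
    """Run a policy in the toy system s' = s + a and return visited states."""
    states, s = [], start
    for _ in range(steps):
        states.append(s)
        s = s + policy(s)
    return states

def fit_student(dataset):
    """Least-squares fit of a linear student a = w * s to (state, action) pairs."""
    num = sum(s * a for s, a in dataset)
    den = sum(s * s for s, _ in dataset) or 1.0
    w = num / den
    return (lambda s: w * s), w

def dagger(iterations=5, seed=0):
    rng = random.Random(seed)
    # Seed the dataset with teacher demonstrations (plain behavior cloning).
    demo_states = rollout(teacher_policy, rng.uniform(-1.0, 1.0))
    dataset = [(s, teacher_policy(s)) for s in demo_states]
    student, w = fit_student(dataset)
    for _ in range(iterations):
        # 1. The student generates a trajectory with its own policy.
        visited = rollout(student, rng.uniform(-1.0, 1.0))
        # 2. The teacher labels every state the student visited.
        dataset += [(s, teacher_policy(s)) for s in visited]
        # 3. The student is retrained on the aggregated dataset.
        student, w = fit_student(dataset)
    return w, len(dataset)

w, size = dagger()
```

In this toy the student can represent the teacher exactly, so it recovers the teacher's gain of -0.5 right away. The point is the structure of the loop. The student chooses which states get visited, and the teacher only supplies the labels.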

Replacing Split Policies with ULT

In the ULT paper, the core idea is to replace the split policies (teacher and student) with a single unified transformer architecture.

Transformers process sequential data, and through self-attention every token can relate to every other token in the input. Used as-is, this breaks the core requirement that only the teacher has access to privileged information, because student tokens could simply attend to the privileged tokens.

Causal Masking

The brilliant idea here is to use causal masking. A causal mask lets each token attend only to itself and to earlier tokens in the sequence. By choosing where each token sits in the sequence, we control which “privileged” information is visible to which tokens during the student’s training.

Curriculum Training

A simple yet effective idea often used is curriculum training. Training starts in easy conditions, and the difficulty (for example, terrain roughness) increases gradually as the policy improves. This makes sim2real transfer more robust, particularly on varied terrain.
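A common way to implement a terrain curriculum is sketched below. It is an assumption for illustration, not the ULT paper's exact rule. Each parallel environment holds a difficulty level. The level goes up when the robot walked far enough in its last episode and goes down when it fell short. The thresholds `promote_dist` and `demote_dist` are hypothetical.

```python
def update_terrain_levels(levels, distances, max_level,
                          promote_dist=4.0, demote_dist=2.0):
    """Move each environment to harder or easier terrain based on progress.

    levels: current terrain difficulty index per environment
    distances: distance each robot walked in its last episode (meters)
    """
    new_levels = []
    for level, dist in zip(levels, distances):
        if dist >= promote_dist:
            level = min(level + 1, max_level)  # walked far: harder terrain
        elif dist < demote_dist:
            level = max(level - 1, 0)          # struggled: easier terrain
        new_levels.append(level)
    return new_levels

print(update_terrain_levels([0, 3, 5], [5.0, 1.0, 4.5], max_level=5))
# prints [1, 2, 5]
```

Letting levels move down as well as up keeps every environment near the edge of what the current policy can handle.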

Deep Dive: The ULT Architecture

The architecture in the ULT paper uses a plain transformer, but the data processing is key:

  1. Tokenization: The input trajectory is a sequence of observation-action pairs. Each pair is concatenated and passed through a simple MLP, which outputs a 128-dim token embedding.
  2. Separation of Info: Privileged (exteroceptive) observations are not mixed with proprioceptive information. They are encoded separately:
    • Proprioceptive (encoded): $d$-dim tokens $H_0, H_1, H_2, H_3, \dots$
    • Exteroceptive (encoded): one $d$-dim token $H_e$
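The two separate encoders can be sketched in plain Python. The sketch rests on several assumptions:

  • The input sizes (`OBS_DIM`, `ACT_DIM`, `EXTERO_DIM`) and the hidden width of 64 are hypothetical.
  • Both encoders output $d = 128$ dimensions, to match the 128-dim token mentioned above.
  • The MLPs are random, untrained 2-layer networks without biases.

```python
import random

def make_mlp(in_dim, hidden_dim, out_dim, seed):
    """Create a randomly initialized 2-layer MLP (no biases) as a closure."""
    rng = random.Random(seed)
    w1 = [[rng.gauss(0.0, 0.1) for _ in range(in_dim)] for _ in range(hidden_dim)]
    w2 = [[rng.gauss(0.0, 0.1) for _ in range(hidden_dim)] for _ in range(out_dim)]

    def forward(x):
        # Hidden layer with ReLU, then a linear output layer.
        hidden = [max(0.0, sum(w * xi for w, xi in zip(row, x))) for row in w1]
        return [sum(w * h for w, h in zip(row, hidden)) for row in w2]
    return forward

D = 128                                  # token width d (assumed to be 128)
OBS_DIM, ACT_DIM, EXTERO_DIM = 4, 2, 6   # hypothetical input sizes

proprio_encoder = make_mlp(OBS_DIM + ACT_DIM, 64, D, seed=0)
extero_encoder = make_mlp(EXTERO_DIM, 64, D, seed=1)

def tokenize(trajectory, extero_obs):
    """trajectory: list of (observation, action) pairs, proprioception only.

    Returns the proprioceptive tokens H_0..H_{t-1} and the separate token H_e.
    """
    proprio_tokens = [proprio_encoder(list(obs) + list(act)) for obs, act in trajectory]
    extero_token = extero_encoder(list(extero_obs))
    return proprio_tokens, extero_token

traj = [([0.1, 0.0, -0.1, 0.2], [0.0, 0.5])] * 3
proprio_tokens, extero_token = tokenize(traj, [0.2] * EXTERO_DIM)
```

Keeping two encoders means privileged features never enter the proprioceptive tokens. Any mixing can only happen later, inside attention, where the mask controls it.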

Sequence Preparation

The input sequence is prepared as: $[H_0, H_1, H_2, \dots, H_{t-1}, H_e]$.

Causal masking lets each token attend only to itself and to the tokens before it, which prevents “future looking”. Placing the exteroceptive token at the end keeps the privileged info separate from the student tokens:

  • No proprioceptive token $H_i$ can attend to $H_e$ or to any later pair.
  • $H_e$ can attend to every proprioceptive token.

This sequence is then passed through the transformer’s attention module with the causal mask applied.
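The sketch below builds the mask in plain Python for a short sequence and applies a masked softmax. It illustrates the general technique, assuming the sequence layout described above. The token names and scores are hypothetical, and the code is not ULT's implementation.

```python
import math

def causal_mask(n):
    """mask[i][j] is True when token i may attend to token j, i.e. j <= i."""
    return [[j <= i for j in range(n)] for i in range(n)]

def masked_softmax(scores, allowed):
    """Softmax restricted to allowed positions; masked positions get weight 0."""
    m = max(s for s, ok in zip(scores, allowed) if ok)  # for numerical stability
    exps = [math.exp(s - m) if ok else 0.0 for s, ok in zip(scores, allowed)]
    total = sum(exps)
    return [e / total for e in exps]

names = ["H0", "H1", "H2", "He"]  # proprioceptive tokens first, privileged last
mask = causal_mask(len(names))
for i, name in enumerate(names):
    print(name, "attends to", [names[j] for j in range(len(names)) if mask[i][j]])
# H0 attends to ['H0']
# H1 attends to ['H0', 'H1']
# H2 attends to ['H0', 'H1', 'H2']
# He attends to ['H0', 'H1', 'H2', 'He']

# Even with a huge raw score on H_e, the student token H2 gives it zero weight.
weights = masked_softmax([0.1, 0.2, 0.3, 50.0], mask[2])
```

The last row of the mask is all True, which is why the output at $H_e$ can summarize both the proprioceptive history and the privileged observation.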

Domain Randomization

As in other sim2real papers, ULT relies on domain randomization. Physical parameters of the simulation (friction, external forces pushing the quadruped, etc.) are randomized during training. This keeps the policy from overfitting to a single simulated setting.
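Per-environment randomization can be sketched as drawing one set of physical parameters for each parallel environment. The post names friction and external forces but gives no values. The parameter names and ranges below are hypothetical.

```python
import random

# Hypothetical ranges; the post names friction and external forces on the
# quadruped as randomized factors but does not give values.
RANDOMIZATION_RANGES = {
    "friction": (0.5, 1.25),
    "push_force_newtons": (0.0, 30.0),
}

def sample_physics_params(num_envs, seed=None):
    """Draw an independent set of physical parameters for each parallel env."""
    rng = random.Random(seed)
    return [
        {name: rng.uniform(lo, hi) for name, (lo, hi) in RANDOMIZATION_RANGES.items()}
        for _ in range(num_envs)
    ]

params = sample_physics_params(num_envs=4, seed=0)
```

Because each environment gets its own draw, a single batch of parallel rollouts already covers a spread of physical conditions.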

Predicting Actions

  • Student Policy: To predict the action, the student uses the transformer’s output at position $t-1$, the last proprioceptive token. Because of causal masking, this output has attended only to $(H_0, H_1, H_2 \dots H_{t-1})$.
  • Teacher Policy: The teacher policy has access to $H_e$. Because $H_e$ is last in the sequence, causal masking lets it attend to every proprioceptive token as well as itself, so its output gathers all the required info. Thus, for the teacher, action generation depends only on the output at $H_e$.

In both cases, an MLP takes the respective token output and predicts the next action.
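Choosing the token for each policy can be sketched as an indexing step on the transformer outputs. In this sketch, `predict_action` and `toy_head` are hypothetical names. The post does not say whether student and teacher share one action MLP, so the head is simply passed in.

```python
def predict_action(outputs, num_proprio, use_teacher, action_head):
    """Choose the transformer output that feeds the action MLP.

    outputs: transformer outputs for the sequence [H_0, ..., H_{t-1}, H_e]
    num_proprio: t, the number of proprioceptive tokens
    use_teacher: True for the teacher path, False for the student path
    action_head: the MLP mapping one output token to an action
    """
    assert len(outputs) == num_proprio + 1, "expected t proprio tokens plus H_e"
    if use_teacher:
        token = outputs[num_proprio]      # output at H_e: sees H_0..H_{t-1} and H_e
    else:
        token = outputs[num_proprio - 1]  # output at H_{t-1}: proprioception only
    return action_head(token)

# Toy usage: three proprioceptive outputs plus one privileged output, 2-dim tokens.
outputs = [[0.0, 1.0], [1.0, 1.0], [2.0, 1.0], [9.0, 9.0]]
toy_head = lambda token: [0.5 * x for x in token]  # stand-in for the action MLP
student_action = predict_action(outputs, 3, use_teacher=False, action_head=toy_head)
teacher_action = predict_action(outputs, 3, use_teacher=True, action_head=toy_head)
```

Both policies come out of the same forward pass; they differ only in which output position is read.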

Training, Losses, and Exploration

Loss Functions

The training involves multiple loss terms:

  1. Trajectory Prediction Loss: Computed between the predicted trajectory and the actual expert trajectory.
  2. Imitation Loss: Calculated between the current action (student) and the future action given by the teacher (expert).
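The two loss terms can be combined as sketched below. The post does not specify the distance function or the weighting. This sketch assumes mean squared error for both terms and hypothetical weights `traj_weight` and `imitation_weight`. In a real autograd framework, the teacher's action would normally be treated as a fixed target. That detail is assumed here and is not shown, since plain Python has no gradients.

```python
def mse(pred, target):
    """Mean squared error between two equal-length lists of floats."""
    assert len(pred) == len(target) and len(pred) > 0
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def total_loss(pred_traj, expert_traj, student_action, teacher_action,
               traj_weight=1.0, imitation_weight=1.0):
    """Weighted sum of a trajectory-prediction loss and an imitation loss.

    pred_traj / expert_traj: lists of per-step vectors (flattened for the MSE)
    student_action / teacher_action: action vectors
    """
    flat_pred = [x for step in pred_traj for x in step]
    flat_expert = [x for step in expert_traj for x in step]
    traj_loss = mse(flat_pred, flat_expert)
    imitation_loss = mse(student_action, teacher_action)
    return traj_weight * traj_loss + imitation_weight * imitation_loss

loss = total_loss([[0.0, 1.0]], [[0.0, 3.0]], [1.0], [0.0])
```

For the toy inputs, the trajectory term is (0 + 4) / 2 = 2 and the imitation term is 1, giving a total of 3.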

Mix-Ratio and Exploration

Since ULT trains both models together, it must decide which policy’s action to execute at each step, so that the states reached by both the teacher and the student are adequately explored.

  • A “mix-ratio” decides whether the teacher or student policy predicts the next action token.
  • In massively parallelized environments, this is implemented as a mask $M$. Each entry (0 or 1) decides which policy acts for that specific environment and step.
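The mask can be sketched for a batch of parallel environments. The sketch makes two assumptions: the mix-ratio is the probability of letting the teacher act, and an entry of 1 means "teacher". The post does not fix either convention, and the function names are hypothetical.

```python
import random

def sample_mix_mask(num_envs, mix_ratio, rng):
    """Return one 0/1 entry per environment: 1 = teacher acts, 0 = student acts.

    mix_ratio is assumed to be the probability of picking the teacher.
    """
    return [1 if rng.random() < mix_ratio else 0 for _ in range(num_envs)]

def mix_actions(teacher_actions, student_actions, mask):
    """Select each environment's executed action according to the mask."""
    return [t if m else s for t, s, m in zip(teacher_actions, student_actions, mask)]

rng = random.Random(0)
mask = sample_mix_mask(num_envs=8, mix_ratio=0.5, rng=rng)
executed = mix_actions([1.0] * 8, [-1.0] * 8, mask)
```

Sampling the mask per environment lets one batched step collect teacher-driven and student-driven transitions at the same time. Varying `mix_ratio` over training shifts the balance between the two policies.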