Paper: LeWorldModel
Introduction
I saw some tweets about Yann LeCun and got interested in his current push for world models.
What are world models? Put simply, a world model is a learned internal simulation of an environment, used to predict the consequences of actions and to plan. This is different from a typical LLM, which is primarily trained to predict the next token in a sequence. A world model is instead trained to predict how the world changes.
A simple example:
LLM:
Learns from previous tokens -> predicts next token
World model:
Learns from current state + action -> predicts next state
LeCun believes that world models are a pathway to AGI, and I think there is merit to that claim. One point he makes often is that most human knowledge is not found in text but comes from non-linguistic sources like vision, motion, spatial reasoning, and interaction. For true AGI, a system needs some physical understanding of the world.
Imagine someone is sitting on a chair. If another person suddenly pulls the chair away, the person sitting will probably fall. This is obvious to us. But how does a baby learn this? Not through language. Babies learn by watching, touching, moving, falling, and interacting with the world.
This paper is a good introduction to how we can train a world model using Joint Embedding Predictive Architecture, or JEPA.
What is JEPA?
JEPA is a framework for learning world models that predict how a system evolves. We start with the current world state, compress it into an embedding, observe that something happens, and then predict the embedding of the future state.
Instead of predicting an image or video directly, the model predicts the next latent representation.
Current versions of JEPA include:
- Representation learning
  - Hide part of the input, such as masking part of an image.
  - Train the model to predict the embeddings of the hidden part.
  - This does not really involve actions.
- Action-conditioned learning
  - Learn from current observation + action -> future latent state.
  - Use pretrained encoders, which makes collapse less likely, but limits the model to what the encoders see.
LeWM claims to train a JEPA directly from raw pixels, without a pretrained encoder, using only two training objectives: a prediction loss on future embeddings and a Gaussian regularizer on the embeddings.
Once we have a world model, how do we use it for decision-making?
The older way is to use the world model as a stand-in simulator and train a policy "in imagination", on rollouts the model generates. The policy learns which action to take in each state.
world model simulates future outcomes
policy learns which actions are good
final policy maps state -> action
At test time, we just use the policy:
observation -> policy -> action
The newer way is to plan on the fly. At test time, the system asks:
What if I do actions [a1, a2, a3]?
Where will I end up?
What if I do actions [b1, b2, b3]?
Where will I end up?
Which action sequence gets me closest to the goal?
This is basically Model Predictive Control.
Introducing LeWorldModel
LeWorldModel, or LeWM, is trained on an offline dataset. It only uses unannotated trajectories of observations and actions. This is a bit like learning from a video dataset in which each frame comes with the action that was taken.
Without any reward, the model only sees what the world looks like at time t and at time t + 1 after action a has happened. With no task-specific signal in place, the only goal is to predict the next state, which amounts to learning the dynamics of the environment.
Model Architecture
LeWM has two main components: an encoder and a predictor.
The encoder turns an observation into a latent representation. The predictor takes in latent representations and actions, then predicts the embedding of the next observation. The predictor does not see only the latest latent; it takes in a history of past latents, which lets it infer things a single frame cannot show, such as velocity.
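To make the two roles concrete, here is a minimal sketch in plain Python. Everything in it is an assumption for illustration: the names `encode` and `predict_next`, the 4-value "image", and the hand-written arithmetic that stands in for learned neural networks. It only shows the shape of the interface, not the paper's architecture.

```python
from typing import List, Sequence

Vector = List[float]

def encode(observation: Sequence[float]) -> Vector:
    """Toy encoder (hypothetical): compress an observation into a
    2-d latent by averaging its left half and its right half."""
    half = len(observation) // 2
    left = sum(observation[:half]) / half
    right = sum(observation[half:]) / (len(observation) - half)
    return [left, right]

def predict_next(latent_history: Sequence[Vector],
                 action_history: Sequence[Vector]) -> Vector:
    """Toy predictor (hypothetical): extrapolate the latest latent using
    the velocity seen in the history, then add the latest action."""
    latest = latent_history[-1]
    if len(latent_history) >= 2:
        previous = latent_history[-2]
        velocity = [a - b for a, b in zip(latest, previous)]
    else:
        velocity = [0.0] * len(latest)  # one frame carries no motion info
    action = action_history[-1]
    return [z + v + a for z, v, a in zip(latest, velocity, action)]

# Two consecutive toy observations and the actions taken at each step.
frames = [[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 2.0, 2.0]]
actions = [[0.0, 0.0], [0.5, 0.0]]

latents = [encode(frame) for frame in frames]
z_hat = predict_next(latents, actions)
print(latents, z_hat)
```

With only the last latent, the toy predictor would have to assume zero velocity. That is the reason a history is passed in.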
The training objective is to minimize the sum of the prediction loss and the regularization loss.
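Written out, one plausible form of this objective is shown below. The squared Euclidean distance, the weight \lambda, and the symbol \mathcal{R} for the regularizer are assumptions made for illustration; a plain sum of the two losses corresponds to \lambda = 1.

```latex
\mathcal{L}
  = \underbrace{\bigl\lVert \hat{z}_{t+1} - z_{t+1} \bigr\rVert_2^2}_{\text{prediction loss}}
  + \lambda \, \underbrace{\mathcal{R}(z)}_{\text{regularization loss}}
```

Here \hat{z}_{t+1} is the predictor's output, z_{t+1} is the encoder's embedding of the real next observation, and \mathcal{R}(z) is computed over a batch of embeddings.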
Prediction loss measures the error between predicted future and real future:
predicted latent says: block moved right
actual latent says: block moved right
=> low loss
predicted latent says: block moved left
actual latent says: block moved right
=> high loss
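The same two cases in code, as a minimal sketch. It assumes the prediction loss is a mean squared error and uses a made-up 2-d latent whose first coordinate means "horizontal displacement"; real latents are not this interpretable.

```python
def prediction_loss(z_hat, z_true):
    """Mean squared error between predicted and actual latent vectors."""
    return sum((p - t) ** 2 for p, t in zip(z_hat, z_true)) / len(z_true)

# Hypothetical latents: first coordinate is horizontal displacement.
moved_right = [1.0, 0.0]
moved_left = [-1.0, 0.0]

low = prediction_loss(moved_right, moved_right)   # prediction matches reality
high = prediction_loss(moved_left, moved_right)   # prediction is wrong
print(low, high)
```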
Prediction loss on its own can be gamed. The encoder could map every image to the same embedding, which drives the prediction loss to zero while leaving the model useless. This failure is known as collapse. To avoid it, SIGReg, the Gaussian regularizer, pushes the latent vectors to spread out instead of letting them collapse to a single point.
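A small sketch of why the regularizer is needed. Note the loud assumption: `gaussian_moment_penalty` below is NOT SIGReg. It is a much simpler stand-in that only penalizes each latent dimension for having a mean away from 0 or a variance away from 1, which is enough to show a collapsed encoder being punished.

```python
import random

def mse(a, b):
    """Mean squared error between two latent vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def gaussian_moment_penalty(batch):
    """Stand-in regularizer (not SIGReg): per-dimension penalty for
    mean != 0 and variance != 1, averaged over dimensions."""
    n, dim = len(batch), len(batch[0])
    penalty = 0.0
    for d in range(dim):
        column = [z[d] for z in batch]
        mean = sum(column) / n
        var = sum((v - mean) ** 2 for v in column) / n
        penalty += mean ** 2 + (var - 1.0) ** 2
    return penalty / dim

# A collapsed encoder maps every observation to the same point.
collapsed = [[0.0, 0.0] for _ in range(256)]
collapsed_pred_loss = mse(collapsed[0], collapsed[1])   # trivially perfect
collapsed_penalty = gaussian_moment_penalty(collapsed)  # but penalized

# Latents that are spread out like a standard Gaussian.
rng = random.Random(0)
spread = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in range(256)]
spread_penalty = gaussian_moment_penalty(spread)

print(collapsed_pred_loss, collapsed_penalty, spread_penalty)
```

The collapsed encoder gets a perfect prediction loss but a large penalty, so the sum of the two terms no longer rewards the shortcut.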
The full training path looks like this:
image_t
-> encoder
-> z_t + action_t
-> predictor
-> z_hat_{t+1}
actual image_{t+1}
-> encoder
-> z_{t+1}
loss = distance(z_hat_{t+1}, z_{t+1}) + anti-collapse regularization
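Once training is done, the world model is frozen and used for latent planning. The encoder embeds the current image and the image of the desired end state. For a candidate action sequence, the predictor rolls the latent state forward for H steps, and the predicted final latent is compared with the latent of the desired end state.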
Essentially, the planner tries many action sequences, computes a cost for each one (how far its predicted end state lands from the goal), keeps the most promising sequences, samples more sequences similar to those, and repeats until the plan stops improving or the iteration budget runs out.
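This sample, select, resample loop has the general shape of the cross-entropy method (CEM), so the sketch below uses CEM as the optimizer; treat that choice as an assumption. The world model here is a toy 1-d latent where each action simply shifts the state, and the names `rollout` and `cem_plan` and all hyperparameters are hypothetical.

```python
import random
import statistics

def rollout(z0, actions):
    """Toy frozen world model: each action shifts the 1-d latent."""
    z = z0
    for a in actions:
        z = z + a
    return z

def cem_plan(z0, z_goal, horizon=5, samples=200, elites=20,
             iterations=10, seed=0):
    """Cross-entropy method over action sequences of length `horizon`."""
    rng = random.Random(seed)
    means = [0.0] * horizon
    stds = [1.0] * horizon
    for _ in range(iterations):
        # 1. Sample candidate action sequences around the current means.
        candidates = [[rng.gauss(m, s) for m, s in zip(means, stds)]
                      for _ in range(samples)]
        # 2. Score each by squared distance of predicted end state to goal.
        candidates.sort(key=lambda seq: (rollout(z0, seq) - z_goal) ** 2)
        # 3. Keep the best few and refit the sampling distribution to them.
        best = candidates[:elites]
        means = [statistics.fmean([seq[t] for seq in best])
                 for t in range(horizon)]
        stds = [statistics.pstdev([seq[t] for seq in best]) + 1e-6
                for t in range(horizon)]
    return means

plan = cem_plan(z0=0.0, z_goal=3.0)
final_latent = rollout(0.0, plan)
print(plan, final_latent)
```

No gradients are needed: the planner only calls the world model forward, which is why the model can stay frozen.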
The planning horizon H is a trade-off. A larger H can give better plans, but it is more expensive and gives prediction errors more steps to compound. Model Predictive Control mitigates this:
Plan 20 steps.
Execute first 2 steps.
Look again.
Plan another 20 steps.
Execute first 2 steps.
Repeat.
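The rationale is that the model does not need to predict the far future perfectly. Only the first few actions of each plan are executed, and every replan starts from a fresh observation, so accumulated errors get corrected.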
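The plan, execute, look again loop can be sketched as follows. The setup is hypothetical: a 1-d state, a world model that believes actions have full effect, a "real" environment where they are 20% weaker, and a simple greedy planner instead of a sampling-based one. The point is only to compare executing one long plan blindly with replanning every 2 steps.

```python
def model_step(z, a):
    """World model's belief: the action moves the state by its full amount."""
    return z + a

def true_step(z, a):
    """Real environment (hypothetical): actions are 20% weaker."""
    return z + 0.8 * a

def plan(z, goal, horizon):
    """Greedy planner inside the model: move toward the goal,
    at most 1.0 per step."""
    actions = []
    for _ in range(horizon):
        a = max(-1.0, min(1.0, goal - z))
        actions.append(a)
        z = model_step(z, a)
    return actions

def run_mpc(z, goal, horizon=20, execute=2, rounds=10):
    """Plan `horizon` steps, execute the first `execute`, then replan."""
    for _ in range(rounds):
        for a in plan(z, goal, horizon)[:execute]:
            z = true_step(z, a)  # act in the real environment
        # Replanning from the observed z is the "look again" step.
    return z

goal = 10.0

# Open loop: execute all 20 planned steps without looking again.
open_loop = 0.0
for a in plan(0.0, goal, 20):
    open_loop = true_step(open_loop, a)

# Closed loop: same total number of executed steps, with replanning.
closed_loop = run_mpc(0.0, goal)
print(open_loop, closed_loop)
```

With the same imperfect model, the open-loop run stops well short of the goal while the replanning run ends very close to it.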
Physical Understanding in LeWM
A vital property of a world model is that it captures the physical structure of its environment. But LeWM only produces latent embeddings, so there is no direct way to tell whether those embeddings encode physical quantities.
To test this, the authors use linear and non-linear probing. They train small models, called probes, to map the latent embeddings to physical quantities. If a simple probe succeeds, the information must already be present in the latent.
The rough intuition:
Bad latent:
random numbers, no clear physical meaning
Good latent:
contains enough information to recover object positions, angles, etc.
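Here is a minimal sketch of a linear probe, written in plain Python. The data is synthetic and entirely made up: "good" latents store a scaled object position in their first coordinate, and "bad" latents are pure noise. The probe is a linear model fitted by gradient descent; `fit_linear_probe` and `probe_mse` are hypothetical helper names, and the non-linear probes the paper also mentions are not covered.

```python
import random

def predict(w, b, z):
    """Linear probe output for one latent vector."""
    return sum(wi * zi for wi, zi in zip(w, z)) + b

def fit_linear_probe(latents, targets, lr=0.05, epochs=300):
    """Fit target ~ w . z + b by full-batch gradient descent on MSE."""
    n, dim = len(latents), len(latents[0])
    w, b = [0.0] * dim, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * dim, 0.0
        for z, y in zip(latents, targets):
            err = predict(w, b, z) - y
            for d in range(dim):
                grad_w[d] += 2 * err * z[d] / n
            grad_b += 2 * err / n
        w = [wi - lr * g for wi, g in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b

def probe_mse(latents, targets, w, b):
    """Mean squared error of the probe on the given data."""
    return sum((predict(w, b, z) - y) ** 2
               for z, y in zip(latents, targets)) / len(targets)

rng = random.Random(0)
positions = [rng.uniform(-1, 1) for _ in range(200)]

# Good latent: first coordinate encodes position, second is a nuisance.
good = [[2.0 * p, rng.gauss(0, 1)] for p in positions]
# Bad latent: unrelated to the position.
bad = [[rng.gauss(0, 1), rng.gauss(0, 1)] for _ in positions]

good_mse = probe_mse(good, positions, *fit_linear_probe(good, positions))
bad_mse = probe_mse(bad, positions, *fit_linear_probe(bad, positions))
print(good_mse, bad_mse)
```

The probe recovers position almost exactly from the good latents and does no better than guessing the average on the bad ones. For brevity this sketch evaluates on the training data; a real probing study would use held-out data.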
Something cool is that even though the model is never trained to decode latent states back into the original images, a decoder fitted on top of the latents is still able to recover the scene. This suggests the latent space contains meaningful information about the physical state.
One interesting experiment is the violation-of-expectation framework. The question is: does the world model get more surprised when the environment violates physical rules?
Surprise is defined as the gap between what the model predicted would happen and what actually happened. One experiment makes objects teleport and measures how surprised the model is.
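As a minimal sketch of that measurement: the toy below uses a 1-d latent, a world model that expects the object to move by exactly the action, and a made-up trajectory where the object teleports at one step. The squared error as the surprise measure and the name `surprise_trace` are assumptions for illustration.

```python
def predict_next(z, action):
    """Toy world model: the object moves by exactly the action."""
    return z + action

def surprise_trace(latents, actions):
    """Per-step surprise: squared gap between the predicted next latent
    and the latent that was actually observed."""
    return [(predict_next(latents[t], actions[t]) - latents[t + 1]) ** 2
            for t in range(len(actions))]

# The object moves +1 per step, except it teleports from 2.0 to 7.0.
observed = [0.0, 1.0, 2.0, 7.0, 8.0]
actions = [1.0, 1.0, 1.0, 1.0]

trace = surprise_trace(observed, actions)
teleport_step = trace.index(max(trace))
print(trace, teleport_step)
```

Surprise stays at zero while the physics holds and spikes only at the transition where the object teleports.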
Conclusion
Honestly, this was one of the hardest papers I have read. It is full of mathematical equations, and ChatGPT was helpful in breaking down the concepts for me.
Most people’s experience with AI is with LLMs, which predict the next token. World models feel like their own paradigm. They are heavily inspired by how humans interact with the world, and they try to model reality itself.
Because of that, LeCun’s argument that world models could be a path toward AGI makes a lot more sense to me now. If intelligence requires more than language, if it requires physical understanding, planning, and the ability to predict consequences, then world models seem like an important piece of the puzzle.
I definitely do not understand every mathematical detail yet, but I feel like I now understand the high-level intuition behind why world models matter and why they could be a key piece in building more capable AI agents. The future is going to be interesting.