I’m fascinated by reasoning: it allows a model to decompose a challenging computation into smaller steps. Quiet-STaR learns which decomposition and planning steps are effective for predicting future text. It generalizes the 2022 Self-Taught Reasoner (STaR), where rationales were inferred from few-shot question-answering (QA) examples. STaR's reliance on curated QA datasets limited its generalizability.
Quiet-STaR generates rationales after every token to explain future text (think), mixes the future-text predictions made with and without rationales (talk), and then learns to generate better rationales using REINFORCE, rewarding rationales that make the true future text more likely (learn).
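Here is a minimal sketch of the learn step for the thoughts sampled at a single position, based on my reading of the paper. Each thought's reward is how much it raised the log-likelihood of the true future tokens compared with the average thought. REINFORCE then scales that thought's own log-probability by the reward. The function names and plain-float inputs are hypothetical. In real training the thought log-probabilities would be differentiable tensors, and autograd would produce the gradient.

```python
def thought_rewards(future_logliks: list[float]) -> list[float]:
    """Reward per sampled thought at one position (hypothetical helper):
    log-likelihood of the true future tokens given that thought, minus the
    average over all thoughts sampled at the same position."""
    baseline = sum(future_logliks) / len(future_logliks)
    return [ll - baseline for ll in future_logliks]


def reinforce_loss(future_logliks: list[float], thought_logprobs: list[float]) -> float:
    """REINFORCE surrogate loss: minimizing it raises log p(thought) for
    thoughts with positive reward and lowers it for negative ones.
    Rewards are constants here; only thought_logprobs would carry gradients."""
    rewards = thought_rewards(future_logliks)
    return -sum(r * lp for r, lp in zip(rewards, thought_logprobs)) / len(rewards)
```

For example, `thought_rewards([-2.0, -1.0, -3.0])` returns `[0.0, 1.0, -1.0]`. Subtracting the mean acts as a baseline: the rewards sum to zero, so a thought is only reinforced when it beats its siblings.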
Method: a tokenwise parallel sampling algorithm, learnable tokens marking each thought’s start and end, and an extended teacher-forcing technique.
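Parallel sampling is possible because of the attention pattern it uses. The sketch below assumes a hypothetical layout: the original text occupies the first `seq_len` slots, and each later block of `seq_len` slots holds one more thought token per position. Text tokens attend causally to text. A thought token attends to the text up to its own position plus the earlier tokens of the same thought, which gives the "diagonal" pattern. This illustrates the idea and is not the paper's code.

```python
def parallel_thought_mask(seq_len: int, thought_steps: int) -> list[list[bool]]:
    """Boolean attention mask (True = may attend) for generating one thought
    per text position in parallel.

    Hypothetical layout: index s * seq_len + i holds step s at position i,
    where s = 0 is the original text and s >= 1 is the s-th thought token
    generated after text position i.
    """
    total = seq_len * (thought_steps + 1)
    mask = [[False] * total for _ in range(total)]
    for q in range(total):
        q_step, q_pos = divmod(q, seq_len)
        for k in range(total):
            k_step, k_pos = divmod(k, seq_len)
            if k_step == 0:
                # Every token sees the original text causally, up to its own position.
                mask[q][k] = k_pos <= q_pos
            else:
                # Thought tokens see only earlier (or the same) steps of the thought
                # at the same position; text queries (q_step == 0) never see thoughts.
                mask[q][k] = q_step >= 1 and k_pos == q_pos and k_step <= q_step
    return mask
```

With `parallel_thought_mask(3, 2)`, the second thought token after position 1 (index 7) can attend to indices 0, 1, 4 and 7. Thoughts at different positions never see each other, so all of them can advance by one token in a single forward pass.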
Result: zero-shot (no dataset-specific fine-tuning) improvements on GSM8K grade-school math word problems (5.9%→10.9%) and CommonsenseQA (36.3%→47.2%). The improvements consistently grow with the number of tokens used in the LM’s internal thoughts.
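The same results, expressed as absolute gains in percentage points and as ratios:

```latex
\begin{aligned}
\text{GSM8K:}\quad & 10.9 - 5.9 = 5.0 \text{ points}, && 10.9 / 5.9 \approx 1.85\times \\
\text{CommonsenseQA:}\quad & 47.2 - 36.3 = 10.9 \text{ points}, && 47.2 / 36.3 \approx 1.30\times
\end{aligned}
```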
Model: Mistral 7B.
Datasets: OpenWebMath and the Colossal Clean Crawled Corpus (C4). No manual annotation required.
Optimizations to try:
- Parallel sampling algorithm -> a scalable training procedure that generates rationales from all token positions in a given string
- Meta-tokens aka “function vectors” at the start and end of each thought -> LM learns when to generate a rationale and when to make a prediction based on that rationale
- Mixing head, applied retrospectively -> decides how much to incorporate the next-token prediction from a given thought into the current next-token prediction
- Non-myopic loss -> including multiple tokens ahead in the language-modeling loss improves the effect of thinking. The objective is to accurately predict the remaining sequence, rather than only the next token.
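A toy sketch of the last two bullets follows. All names are hypothetical, and the mixing head is deliberately simplified to one linear layer plus a sigmoid over the concatenated hidden states; the paper describes a shallow MLP. In this sketch the resulting weight interpolates the logits with and without the thought before renormalizing. The non-myopic loss averages the negative log-likelihood over several true future tokens instead of one.

```python
import math


def log_softmax(logits: list[float]) -> list[float]:
    m = max(logits)  # subtract the max for numerical stability
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def mix_predictions(base_logits, thought_logits, h_base, h_thought, head_w, head_b):
    """Hypothetical mixing head: one linear layer + sigmoid gives a weight in (0, 1)
    computed from the hidden states with and without the thought."""
    score = sum(w * h for w, h in zip(head_w, h_base + h_thought)) + head_b
    alpha = 1.0 / (1.0 + math.exp(-score))
    # alpha -> 0 falls back to the no-thought prediction; alpha -> 1 trusts the thought.
    mixed = [alpha * t + (1.0 - alpha) * b for t, b in zip(thought_logits, base_logits)]
    return log_softmax(mixed)


def non_myopic_nll(step_logprobs: list[list[float]], future_tokens: list[int]) -> float:
    """Mean negative log-likelihood of several true future tokens,
    one log-prob vector per step, instead of only the immediate next token."""
    total = sum(lp[tok] for lp, tok in zip(step_logprobs, future_tokens))
    return -total / len(future_tokens)
```

Because the head can drive the weight toward zero, an unhelpful thought costs little early in training, when most sampled thoughts are still noise.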
Related models:
- REINFORCE comes from Ronald J. Williams, “Simple statistical gradient-following algorithms for connectionist reinforcement learning,” Machine Learning, 8:229–256, 1992.
- “Process-based” supervision filters out incorrect reasoning traces.
- V-STaR (Hosseini et al., 2024) trains a verifier to guide generation, which also improves performance.
- TRICE (Hoffman et al., 2024) maximizes the marginal likelihood of the correct answer given several reasoning traces per problem.
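For reference, this is the core identity behind Williams' REINFORCE, in its baseline-subtracted form. Here $\pi_\theta$ is the policy (for Quiet-STaR, the LM sampling a thought), $R$ is the reward, and $b$ is a baseline that does not depend on the sampled $a$. Subtracting an average reward is a common choice of $b$ that reduces variance without biasing the gradient.

```latex
\nabla_\theta\, \mathbb{E}_{a \sim \pi_\theta}\left[R(a)\right]
  = \mathbb{E}_{a \sim \pi_\theta}\left[\left(R(a) - b\right)\, \nabla_\theta \log \pi_\theta(a)\right]
```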
Resource: Paper on arXiv: https://arxiv.org/abs/2403.09629