TAID (ICLR 2025)

📚 Paper | 🤗 Hugging Face | 📝 Blog [EN | JP]

[Overview figure]

This is the official PyTorch implementation of "TAID: Temporally Adaptive Interpolated Distillation for Efficient Knowledge Transfer in Language Models".
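For reference, the core idea is to distill against an intermediate distribution that interpolates between the student and the teacher, gradually shifting from the student toward the teacher over training. The snippet below is a minimal, simplified sketch of such an interpolated loss; it is not the repository's distil_losses implementation, and the schedule for the interpolation parameter, temperature scaling, and padding masks used in the actual code are omitted.

import torch
import torch.nn.functional as F

def taid_loss(student_logits, teacher_logits, alpha):
    """Sketch of an interpolated distillation objective.

    The intermediate target is a convex combination of the (detached)
    student distribution and the teacher distribution, controlled by
    alpha in [0, 1]; alpha moves from near 0 toward 1 over training.
    """
    p_student = F.softmax(student_logits, dim=-1).detach()
    p_teacher = F.softmax(teacher_logits, dim=-1)
    p_interp = (1.0 - alpha) * p_student + alpha * p_teacher  # intermediate target
    log_q = F.log_softmax(student_logits, dim=-1)
    # Forward KL from the interpolated target to the student.
    return F.kl_div(log_q, p_interp, reduction="batchmean")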

Installation

pip install torch==2.2.2 --index-url https://download.pytorch.org/whl/cu121
pip install -r requirements.txt
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
python setup.py install

# make sure to log in to Hugging Face and wandb
huggingface-cli login
wandb login

We conducted our experiments with Python 3.10.12 and CUDA 12.3 on 8 × H100 80GB GPUs.

Data Preparation

Use the following script to prepare the data for Phi-3-mini:

python prepare_ultrachat.py --model_type phi-3 --output_dir data
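For illustration, the preparation broadly amounts to rendering UltraChat conversations with the student model's chat template. The sketch below is an assumption of that workflow, not the contents of prepare_ultrachat.py; the model name, dataset split, column name, and output path are placeholders.

from datasets import load_dataset
from transformers import AutoTokenizer

# Illustrative only: the model, split, and paths are assumptions,
# not taken from prepare_ultrachat.py.
tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
dataset = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")

def to_text(example):
    # Render each multi-turn conversation with the model's chat template.
    return {"text": tokenizer.apply_chat_template(example["messages"], tokenize=False)}

dataset = dataset.map(to_text, remove_columns=dataset.column_names)
dataset.save_to_disk("data/ultrachat_phi-3")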

Training

We provide bash scripts for various methods in the scripts directory. For example, the scripts for the experiments distilling Llama-2 into TinyLlama are in the scripts/llama-2 directory; running the following command executes training with TAID.

bash scripts/llama-2/taid.sh
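During training, TAID increases the interpolation parameter from a small initial value toward 1, so the intermediate target moves from the student's distribution toward the teacher's. The paper uses an adaptive, objective-aware update; the linear ramp below is only a simplified illustration (function and argument names are hypothetical) that pairs with the loss sketch above.

def linear_alpha(step, total_steps, alpha_start=0.1, alpha_end=1.0):
    """Simplified linear schedule for the interpolation parameter.

    The adaptive variant described in the paper adjusts alpha based on
    training progress of the objective; this linear ramp only illustrates
    the temporal shift from the student toward the teacher.
    """
    t = min(step / max(total_steps, 1), 1.0)
    return alpha_start + (alpha_end - alpha_start) * t

# Example: inside the training loop (taid_loss as sketched above).
# alpha = linear_alpha(step, total_steps)
# loss = taid_loss(student_logits, teacher_logits, alpha)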

Acknowledgement

We would like to thank the developers of the source models for their contributions and for making their work available. The implementation of the baseline losses in distil_losses and the sampling-based generation in sampler.py are based on the following repositories, and we are grateful for their work.

Citation

To cite our work, you can use the following:

@misc{sakana2025taid,
      title         = {TAID: Temporally Adaptive Interpolated Distillation for Efficient Knowledge Transfer in Language Models}, 
      author        = {Makoto Shing and Kou Misaki and Han Bao and Sho Yokoi and Takuya Akiba},
      year          = {2025},
      eprint        = {2501.16937},
      archivePrefix = {arXiv},
      primaryClass  = {cs.LG},
      url           = {https://arxiv.org/abs/2501.16937}
}
