
Quick Code Review: Auto-vectorization #2

Open
gaxler opened this issue Aug 1, 2023 · 8 comments
@gaxler

gaxler commented Aug 1, 2023

Hi Sasha! Nice to see your take on llama2.rs!
I did a port of Andrej's llama2.c here

I had a chance to go over the code and compare it to my version. The main thing I want to mention is that you can get the compiler to auto-vectorize some computations (most notably matmul).

I did a quick benchmark of our implementations: yours runs at ~52 t/s, mine at ~75 t/s (on a 2-CPU/4 GB Codespaces VM, running the stories15M model). My guess is that most of the difference comes from the compiler not being able to auto-vectorize your matmul implementation.

A good way to help the compiler auto-vectorize is to use iterators as much as possible. The key idea is to replace an indexed loop like this one:

    xout.par_iter_mut().enumerate().for_each(|(i, v)| {
        let mut val = 0.0;
        for j in 0..n {
            val += w[i * n + j] * x[j];
        }
        *v = val;
    })

with an iterator-based version. The following closes most of the gap and puts your implementation at ~74 t/s:

    xout.par_iter_mut().enumerate().for_each(|(i, v)| {
        *v = w[i * n..(i + 1) * n]
            .iter()
            .zip(x.iter())
            .fold(0f32, |acc, (&_w, &_x)| acc + _w * _x);
    })
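For reference, here's the same thing as a self-contained function (a sketch: the `matmul` signature and the rayon import are my assumptions about how the surrounding code is set up):

    use rayon::prelude::*;

    /// xout = W * x, where W is a (rows x n) matrix stored row-major.
    /// The slice/zip/fold form gives the compiler visible loop bounds,
    /// so it can elide bounds checks and auto-vectorize the dot product;
    /// rayon parallelizes across rows.
    fn matmul(xout: &mut [f32], w: &[f32], x: &[f32], n: usize) {
        xout.par_iter_mut().enumerate().for_each(|(i, v)| {
            *v = w[i * n..(i + 1) * n]
                .iter()
                .zip(x.iter())
                .fold(0f32, |acc, (&w_ij, &x_j)| acc + w_ij * x_j);
        });
    }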
You can get the compiler to auto-vectorize with AVX2 and FMA by passing the target-feature compiler flags:

    RUSTFLAGS="-C target-feature=+avx2,+fma"
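For example, when building with cargo (release mode matters for vectorization):

    RUSTFLAGS="-C target-feature=+avx2,+fma" cargo build --release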

@srush
Owner

srush commented Aug 1, 2023

Amazing, that's really helpful to know. Thanks for pointing it out.

Do you plan on continuing to work on this? Was planning on moving on, but now I'm kind of curious to implement quantization to scale to bigger models. Would be happy to collaborate if you were interested.

@srush
Owner

srush commented Aug 1, 2023

Nice. This bumped me up from 0.92 t/s to 1.02 t/s on llama2 7B.

@gaxler
Author

gaxler commented Aug 1, 2023

> Amazing, that's really helpful to know. Thanks for pointing it out.
>
> Do you plan on continuing to work on this? Was planning on moving on, but now I'm kind of curious to implement quantization to scale to bigger models. Would be happy to collaborate if you were interested.

Yeah, that's what I hope to do before moving on. Would love to collaborate. I started playing with some prototypes on this branch (it's not very reader-friendly yet)

> Nice. This bumped me up from 0.92 t/s to 1.02 t/s on llama2 7B.

Nice! I wonder why the speed-up is so small compared to the 15M model. Maybe the CPU is waiting on mmap page swaps?

@srush
Owner

srush commented Aug 1, 2023

Nice, I will try to catch up on your code.

Some of the HF people recommended trying to do GPTQ inference (quantized-weight × full-precision mat-vec). Which version are you doing?

@gaxler
Author

gaxler commented Aug 2, 2023

My code is mostly experimenting with ways to build a clean matmul interface. I did a naive row-wise i8 quantization of the weights, with matmuls that accumulate into f32. But that's really just the first thing that popped to mind.
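Roughly, the idea looks like this (an illustrative sketch, not the actual code on the branch; names and layout are made up): each weight row gets its own f32 scale derived from its max-abs value, and the matvec dequantizes on the fly while accumulating in f32.

    /// Row-wise i8 quantization: each row keeps its own f32 scale.
    struct QuantizedMatrix {
        cols: usize,
        data: Vec<i8>,    // rows * cols quantized weights, row-major
        scales: Vec<f32>, // one scale per row
    }

    fn quantize_rowwise(w: &[f32], rows: usize, cols: usize) -> QuantizedMatrix {
        let mut data = Vec::with_capacity(rows * cols);
        let mut scales = Vec::with_capacity(rows);
        for r in 0..rows {
            let row = &w[r * cols..(r + 1) * cols];
            // Map the row's max-abs value onto the i8 range [-127, 127].
            let max_abs = row.iter().fold(0f32, |m, &v| m.max(v.abs()));
            let scale = if max_abs > 0.0 { max_abs / 127.0 } else { 1.0 };
            scales.push(scale);
            data.extend(row.iter().map(|&v| (v / scale).round() as i8));
        }
        QuantizedMatrix { cols, data, scales }
    }

    /// Matvec that dequantizes on the fly, accumulating in f32.
    fn matmul_q8(xout: &mut [f32], w: &QuantizedMatrix, x: &[f32]) {
        for (i, v) in xout.iter_mut().enumerate() {
            let row = &w.data[i * w.cols..(i + 1) * w.cols];
            let dot = row
                .iter()
                .zip(x.iter())
                .fold(0f32, |acc, (&q, &xv)| acc + (q as f32) * xv);
            *v = dot * w.scales[i];
        }
    }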

@srush
Owner

srush commented Sep 2, 2023

Hi! I saw that you're also a maintainer of Triton and worked on the AoT compiler. I'm playing around with setting this project up to use Triton, just to learn it. Do you have any tips for getting this to work? I tried exporting PTX, which worked reasonably well at first, but I think I'm running into issues calling into it from Rust. Curious if you have pointers to recommended ways to do it?

My hacky code: https://github.com/srush/llama2.rs/pull/35/files#diff-7c199e27f9cec983de845ad01b4fd4e558534ee33fd49d8134cbab879361af67R158
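For context, the basic load-and-launch pattern I'm attempting on the Rust side looks roughly like this (a sketch using the cust crate; the kernel name, argument list, and launch dimensions are placeholders, and a real Triton-generated kernel may expect different arguments):

    use cust::prelude::*;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        // Create a CUDA context on the default device.
        let _ctx = cust::quick_init()?;

        // Load the PTX that the Triton AoT compiler exported.
        let ptx = std::fs::read_to_string("matmul_kernel.ptx")?;
        let module = Module::from_ptx(&ptx, &[])?;
        let kernel = module.get_function("matmul_kernel")?; // placeholder name

        let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;

        // Placeholder buffers; real code would upload the model tensors.
        let x = DeviceBuffer::from_slice(&[0f32; 1024])?;
        let mut out = DeviceBuffer::from_slice(&[0f32; 1024])?;

        unsafe {
            // Grid/block sizes are placeholders too.
            cust::launch!(kernel<<<32, 128, 0, stream>>>(
                out.as_device_ptr(),
                x.as_device_ptr(),
                1024i32
            ))?;
        }
        stream.synchronize()?;
        Ok(())
    }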

@gaxler
Author

gaxler commented Sep 2, 2023

What issues are you seeing? Is it slowness or something Triton-related?

I left some comments in the PR.

@srush
Owner

srush commented Sep 2, 2023

Thanks! Once I got it running it was fast, but when I tried to further optimize the Triton code, the Rust version went out of sync with the Python version. I'm trying to make a minimal example.
