Difference between checkpoint FASTSAM3D and Finetuned-SAMMED3D #6

MinxuanQin opened this issue Sep 10, 2024 · 9 comments

@MinxuanQin

Thank you for sharing the excellent code and checkpoints! I have run the code described in Readme.md and would like to check whether I have understood it correctly.

The current versions of distillation.py and validate_student.py use an ImageEncoder with the so-called "woatt" attention (window attention), not the 3D sparse flash attention. The validate_student.py script loads the tiny image encoder (the first checkpoint uploaded on GitHub) as the image encoder; the remaining parts come from the fine-tuned teacher model (the second uploaded checkpoint, "Finetuned-SAMMED3D"). Does the third checkpoint, "FASTSAM3D," combine the tiny encoder and the remaining parts into one model?

I think those checkpoints do not use build_sam3D_flash.py, build_sam3D_dilatedattention.py, or build_3D_decoder.py. Is that right? Does this checkpoint perform best among all encoder and decoder variants? Thank you!
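For reference, one way such a combined checkpoint could be assembled from the two separately released files is sketched below; the file names and state-dict keys are assumptions for illustration, not the repository's actual layout:

```python
import torch

# Hypothetical file names and keys -- substitute the actual released checkpoints.
tiny_encoder_sd = torch.load("tiny_image_encoder.pth", map_location="cpu")
teacher_ckpt = torch.load("finetuned_sammed3d.pth", map_location="cpu")
teacher_sd = teacher_ckpt.get("model_state_dict", teacher_ckpt)

# Keep the prompt encoder and mask decoder from the fine-tuned teacher,
# and swap in the distilled tiny image encoder.
combined = {k: v for k, v in teacher_sd.items() if not k.startswith("image_encoder.")}
combined.update({f"image_encoder.{k}": v for k, v in tiny_encoder_sd.items()})

torch.save({"model_state_dict": combined}, "fastsam3d_combined.pth")
```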

@skill-diver
Collaborator

The flash attention part is only used for inference, not for distillation. Feel free to use flash attention for your inference.

@MinxuanQin
Author

Thank you for your reply! So you have distilled a lightweight image encoder with only 6 layers, where the first two layers do not contain attention. For inference, there are no checkpoints with flash attention available, but I can distill an image encoder with flash attention and then use it for inference. Do I understand that correctly?
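As a purely illustrative sketch (not the repository's actual ImageEncoder), a 6-block encoder in which the first two blocks skip self-attention and keep only the MLP could look roughly like this:

```python
import torch
import torch.nn as nn

class Block(nn.Module):
    """Transformer-style block; attention can be disabled for the early layers."""
    def __init__(self, dim: int, num_heads: int, use_attention: bool):
        super().__init__()
        self.use_attention = use_attention
        if use_attention:
            self.norm1 = nn.LayerNorm(dim)
            self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x):
        if self.use_attention:
            h = self.norm1(x)
            attn_out, _ = self.attn(h, h, h)
            x = x + attn_out
        return x + self.mlp(self.norm2(x))

# Six blocks in total; the first two skip attention, as described above.
blocks = nn.ModuleList(
    [Block(dim=384, num_heads=6, use_attention=(i >= 2)) for i in range(6)]
)
x = torch.randn(2, 64, 384)  # (batch, tokens, dim) toy input
for blk in blocks:
    x = blk(x)
```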

@skill-diver
Collaborator

You are correct except for one point: you can use our checkpoint for inference; it supports flash attention.

@MinxuanQin
Author

Thank you for your quick reply! I am not familiar with flash attention, so this may be a silly question: based on your answer and the code, I think flash attention is an efficient attention computation mechanism that does not change the network architecture. So the checkpoint can be run both with and without flash attention, right?
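That reading matches how flash attention is generally understood: it changes how attention is computed (tiled, memory-efficient kernels), not what is computed, so the same weights can run with or without it. A small PyTorch sketch, independent of this repository and using the PyTorch ≥ 2.3 backend-selection API, illustrates that both paths produce numerically matching outputs:

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

# Flash attention kernels require a CUDA device and half precision.
q = torch.randn(1, 8, 64, 32, device="cuda", dtype=torch.float16)
k = torch.randn(1, 8, 64, 32, device="cuda", dtype=torch.float16)
v = torch.randn(1, 8, 64, 32, device="cuda", dtype=torch.float16)

# Same operation, two different kernels: flash attention vs. the plain math path.
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out_flash = F.scaled_dot_product_attention(q, k, v)
with sdpa_kernel(SDPBackend.MATH):
    out_math = F.scaled_dot_product_attention(q, k, v)

# The results agree up to kernel-level numerical tolerance.
print(torch.allclose(out_flash, out_math, atol=1e-3))
```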

@MinxuanQin
Author

I have another question regarding the distillation process. According to utils/prepare_nnunet.py, the images and labels from one dataset are stored under label_name/dataset_name/imagesTr and label_name/dataset_name/labelsTr, but preparelabel.py and validation.py only pick up the top-level directory (label_name) because they use all_dataset_paths = glob(join(args.test_data_path)) rather than all_dataset_paths = glob(join(args.test_data_path,'*','*')).
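A small sketch of the difference between the two glob patterns, assuming the nested layout described above (the root directory name is a placeholder):

```python
from glob import glob
from os.path import join

# Hypothetical root: <test_data_path>/<label_name>/<dataset_name>/imagesTr
test_data_path = "data/validation"

# As written, this returns only the root directory itself (a single entry).
all_dataset_paths = glob(join(test_data_path))
print(all_dataset_paths)

# With two wildcard levels, every label_name/dataset_name directory is matched,
# which corresponds to the layout produced by utils/prepare_nnunet.py.
all_dataset_paths = glob(join(test_data_path, "*", "*"))
print(all_dataset_paths)
```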

For the distillation process, could I use each image only once rather than the data generated by utils/prepare_nnunet.py, since that script duplicates images for each class of a single subject and preparelabel.py does not need labels?

Thank you very much for your help!

@skill-diver
Collaborator

> Thank you for your quick reply! I am not familiar with flash attention, so this may be a silly question: based on your answer and the code, I think flash attention is an efficient attention computation mechanism that does not change the network architecture. So the checkpoint can be run both with and without flash attention, right?

You are right.

@skill-diver
Collaborator

> I have another question regarding the distillation process. According to utils/prepare_nnunet.py, the images and labels from one dataset are stored under label_name/dataset_name/imagesTr and label_name/dataset_name/labelsTr, but preparelabel.py and validation.py only pick up the top-level directory (label_name) because they use all_dataset_paths = glob(join(args.test_data_path)) rather than all_dataset_paths = glob(join(args.test_data_path,'*','*')).
>
> For the distillation process, could I use each image only once rather than the data generated by utils/prepare_nnunet.py, since that script duplicates images for each class of a single subject and preparelabel.py does not need labels?
>
> Thank you very much for your help!

You need to use prepare_nnunet.py. The model needs to learn from these preprocessed images (cropping, registration, etc. is necessary).

@MinxuanQin
Author

Got it. Thank you very much!

@MinxuanQin
Author

I have a question regarding the distillation loss. In the paper, the objective of the layer-wise progressive distillation process is described as

$$\mathbb{E}_x \left[ \frac{1}{k} \sum_{i=1}^{k} \left\Vert f_{\text{teacher}}^{(2i)}(x) - f_{\text{student}}^{(i)}(x) \right\Vert \right]$$

where $k$ varies from 1 to 6 based on the current and total training iterations. In the code of distillation.py, I think the variable curlayer of the class BaseTrainer plays the role of $k$, but the loss in this case is defined as loss = self.seg_loss(output[self.curlayer], label[self.curlayer]), where only the L2 norm of the current layer is computed, not the sum from $i=1$ to $i=k$, as far as I can see.
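For concreteness, the two readings under discussion can be sketched as follows; the tensor names and shapes are placeholders, not the repository's code. The paper's objective averages the L2 distance over layers $i = 1$ to $k$, whereas the quoted line only penalizes the current layer:

```python
import torch

def progressive_loss(student_feats, teacher_feats, k):
    """Paper-style objective: average over i = 1..k of
    || f_teacher^(2i)(x) - f_student^(i)(x) || (0-based list indexing below)."""
    terms = [torch.norm(teacher_feats[2 * i - 1] - student_feats[i - 1])
             for i in range(1, k + 1)]
    return torch.stack(terms).mean()

def current_layer_loss(student_feats, teacher_feats, k):
    """Reading of the quoted code: only the k-th layer pair contributes."""
    return torch.norm(teacher_feats[2 * k - 1] - student_feats[k - 1])

# Toy example: 12 teacher layers, 6 student layers, feature maps of identical shape.
teacher_feats = [torch.randn(4, 384) for _ in range(12)]
student_feats = [torch.randn(4, 384) for _ in range(6)]
print(progressive_loss(student_feats, teacher_feats, k=3))
print(current_layer_loss(student_feats, teacher_feats, k=3))
```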

In addition, I have read in the paper that the number of iterations is set to 36 for the first, layer-wise distillation process. I would like to know how many iterations were used for the logit-level distillation process. Thank you!
