Skip to content

Commit

Permalink
move many values into the config
Browse files Browse the repository at this point in the history
  • Loading branch information
Zarxrax committed Feb 18, 2024
1 parent c64a256 commit 985af49
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 29 deletions.
38 changes: 34 additions & 4 deletions cutie/config/gui_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,28 @@ defaults:
- _self_
- model: base

# workspace configuration
# workspace path
workspace_root: ./workspace

force_cpu: False
amp: True
weights: weights/cutie-base-mega.pth

# RITM interactive segmentation settings
# All "size" parameters represent the length of the longer edge
ritm_weights: weights/coco_lvis_h18_itermask.pth
ritm_anime_weights: weights/aniclick_v2_h18_itermask.pth
ritm_max_size: 960
ritm_zoom_size: 512
ritm_expansion_ratio: 1.4
ritm_use_anime: False

# All "size" parameters represent the length of the shorter edge
# maximum internal processing size; reducing this speeds up processing
# "size" parameters represent the length of the shorter edge
max_internal_size: 480

# maximum size for extracting frames; the output will also be in this size
# reducing this mainly speed up I/O
# it should not be smaller than the internal size
# reducing this mainly speed up I/O, it should not be smaller than the internal size
max_overall_size: 2160
buffer_size: 20

Expand All @@ -35,6 +42,12 @@ output_fps: 23.976
output_quantizer: 14
output_refine: False

# refine edges settings. Blur must be an odd number.
erode_radius: 1
erode_blur: 3
dilate_radius: 1
dilate_blur: 3

# memory settings
use_long_term: True
mem_every: 5
Expand All @@ -54,5 +67,22 @@ long_term:
max_num_tokens: 4000
buffer_tokens: 2000

# settings for processing quality options. each one can change the max_internal_size and the max_num_tokens
default_quality: NormalQuality

LowQuality:
max_internal_size: 400
max_num_tokens: 3000
NormalQuality:
max_internal_size: 480
max_num_tokens: 4000
HighQuality:
max_internal_size: 540
max_num_tokens: 4000
UltraQuality:
max_internal_size: 720
max_num_tokens: 4000

# not sure what these are
save_aux: False
flip_aug: False
5 changes: 4 additions & 1 deletion cutie_roto.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,15 @@ def get_arguments():

# general setup
torch.set_grad_enabled(False)
if torch.cuda.is_available():
if cfg.force_cpu:
device = 'cpu'
elif torch.cuda.is_available():
device = 'cuda'
elif torch.backends.mps.is_available():
device = 'mps'
else:
device = 'cpu'

args.device = device
log.info(f'Using device: {device}')

Expand Down
7 changes: 4 additions & 3 deletions gui/click_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@


class ClickController:
def __init__(self, checkpoint_path: str, device: str = 'cuda', max_size: int = 960):
def __init__(self, checkpoint_path: str, max_size: int = 960, zoom_size: int = 512, expansion_ratio: float = 1.4, device: str = 'cuda'):

model = utils.load_is_model(checkpoint_path, device, cpu_dist_maps=True)

# Predictor params
zoomin_params = {
'skip_clicks': 1,
'target_size': 512,
'expansion_ratio': 1.4,
'target_size': zoom_size,
'expansion_ratio': expansion_ratio,
}

predictor_params = {
Expand Down
12 changes: 8 additions & 4 deletions gui/exporter_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,10 @@ def convert_mask_to_binary(self, mask_folder: str, output_path: str, progress_ca
self.progressbar_update(1.0)

def refine_masks(self, mask_folder: str, soft_mask_folder: str, output_path: str, progress_callback=None) -> None:
erode_radius = self.cfg['erode_radius']
erode_blur = self.cfg['erode_blur']
dilate_radius = self.cfg['dilate_radius']
dilate_blur = self.cfg['dilate_blur']
refined_mask_path = path.join(output_path, 'refined_masks')
masks = [img for img in sorted(listdir(mask_folder)) if img.endswith(".png")]
soft_masks = [img for img in sorted(listdir(soft_mask_folder)) if img.endswith(".png")]
Expand All @@ -402,10 +406,10 @@ def refine_masks(self, mask_folder: str, soft_mask_folder: str, output_path: str
smsk = cv2.imread(soft_path, cv2.IMREAD_GRAYSCALE)

#refine mask
bmsk_erosion = cv2.erode(bmsk, None, iterations=3)
bmsk_dilation = cv2.dilate(bmsk, None, iterations=2)
bmsk_erosion = cv2.GaussianBlur(bmsk_erosion, (5, 5), 0)
bmsk_dilation = cv2.GaussianBlur(bmsk_dilation, (5, 5), 0)
bmsk_erosion = cv2.erode(bmsk, None, iterations=erode_radius)
bmsk_dilation = cv2.dilate(bmsk, None, iterations=dilate_radius)
bmsk_erosion = cv2.GaussianBlur(bmsk_erosion, (erode_blur, erode_blur), 0)
bmsk_dilation = cv2.GaussianBlur(bmsk_dilation, (dilate_blur, dilate_blur), 0)
add_mask = cv2.add(smsk, bmsk_erosion)
sub_mask = cv2.subtract(smsk, bmsk_dilation)
refined_mask = cv2.subtract(add_mask, sub_mask)
Expand Down
16 changes: 13 additions & 3 deletions gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,14 @@ def __init__(self, controller, cfg: DictConfig) -> None:
self.comboBox_quality.addItem(u"High")
self.comboBox_quality.addItem(u"Ultra")
self.comboBox_quality.setObjectName(u"comboBox_quality")
self.comboBox_quality.setCurrentText("Normal")
if cfg.default_quality == "LowQuality":
self.comboBox_quality.setCurrentText("Low")
elif cfg.default_quality == "HighQuality":
self.comboBox_quality.setCurrentText("High")
elif cfg.default_quality == "UltraQuality":
self.comboBox_quality.setCurrentText("Ultra")
else:
self.comboBox_quality.setCurrentText("Normal")
self.comboBox_quality.currentIndexChanged.connect(controller.on_quality_change)

self.modelselect_label = QLabel(u"Click Segmentation Model")
Expand All @@ -188,7 +195,10 @@ def __init__(self, controller, cfg: DictConfig) -> None:
self.comboBox_modelselect.setObjectName(u"comboBox_modelselect")
self.comboBox_modelselect.addItem(u"Standard")
self.comboBox_modelselect.addItem(u"Anime")
self.comboBox_modelselect.setCurrentText("Standard")
if cfg.ritm_use_anime is True:
self.comboBox_modelselect.setCurrentText("Anime")
else:
self.comboBox_modelselect.setCurrentText("Standard")
self.comboBox_modelselect.currentIndexChanged.connect(controller.on_modelselect_change)

# import background layer
Expand Down Expand Up @@ -325,7 +335,7 @@ def __init__(self, controller, cfg: DictConfig) -> None:
quality_area.addWidget(self.comboBox_quality)
right_area.addLayout(quality_area)

#quality combobox
#model combobox
modelselect_area = QHBoxLayout()
modelselect_area.setAlignment(Qt.AlignmentFlag.AlignBottom)
modelselect_area.addWidget(self.modelselect_label)
Expand Down
7 changes: 5 additions & 2 deletions gui/interactive_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def index_numpy_to_one_hot_torch(mask: np.ndarray, num_classes: int):
return F.one_hot(mask, num_classes=num_classes).permute(2, 0, 1).float()


"""
Some constants for visualization

# Some constants for visualization
"""
try:
if torch.cuda.is_available():
Expand All @@ -38,6 +38,9 @@ def index_numpy_to_one_hot_torch(mask: np.ndarray, num_classes: int):
device = torch.device("cpu")
except:
device = torch.device("cpu")
"""
# get existing device instead of detecting again
device = torch.cuda.current_device()

color_map_np = np.frombuffer(davis_palette, dtype=np.uint8).reshape(-1, 3).copy()
# scales for better visualization
Expand Down
29 changes: 17 additions & 12 deletions gui/main_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ def __init__(self, cfg: DictConfig) -> None:
self.gui.text('Initialized.')
self.initialized = True

# set the quality per the config
self.on_quality_change()

# try to load the default overlay
#self._try_load_layer('./docs/uiuc.png')
#self.gui.set_object_color(self.curr_object)
Expand All @@ -132,15 +135,17 @@ def initialize_networks(self) -> None:
self.cutie = CUTIE(self.cfg).eval().to(self.device)
model_weights = torch.load(self.cfg.weights, map_location=self.device)
self.cutie.load_weights(model_weights)

self.click_ctrl = ClickController(self.cfg.ritm_weights, device=self.device)
if self.cfg.ritm_use_anime is True:
self.click_ctrl = ClickController(self.cfg.ritm_anime_weights, self.cfg.ritm_max_size, self.cfg.ritm_zoom_size, self.cfg.ritm_expansion_ratio, device=self.device)
else:
self.click_ctrl = ClickController(self.cfg.ritm_weights, self.cfg.ritm_max_size, self.cfg.ritm_zoom_size, self.cfg.ritm_expansion_ratio, device=self.device)

def on_modelselect_change(self):
if self.gui.comboBox_modelselect.currentText() == 'Standard':
self.click_ctrl = ClickController(self.cfg.ritm_weights, device=self.device)
self.click_ctrl = ClickController(self.cfg.ritm_weights, self.cfg.ritm_max_size, self.cfg.ritm_zoom_size, self.cfg.ritm_expansion_ratio, device=self.device)
self.gui.text('Standard segmentation model loaded.')
else:
self.click_ctrl = ClickController(self.cfg.ritm_anime_weights, device=self.device)
self.click_ctrl = ClickController(self.cfg.ritm_anime_weights, self.cfg.ritm_max_size, self.cfg.ritm_zoom_size, self.cfg.ritm_expansion_ratio, device=self.device)
self.gui.text('Anime segmentation model loaded.')

def hit_number_key(self, number: int):
Expand Down Expand Up @@ -533,17 +538,17 @@ def on_work_max_change(self):
def on_quality_change(self):
if self.initialized:
if self.gui.comboBox_quality.currentText() == 'Low':
self.gui.long_mem_max.setValue(3000)
self.gui.quality_box.setValue(400)
self.gui.long_mem_max.setValue(self.cfg.LowQuality.max_num_tokens)
self.gui.quality_box.setValue(self.cfg.LowQuality.max_internal_size)
elif self.gui.comboBox_quality.currentText() == 'Normal':
self.gui.long_mem_max.setValue(4000)
self.gui.quality_box.setValue(480)
self.gui.long_mem_max.setValue(self.cfg.NormalQuality.max_num_tokens)
self.gui.quality_box.setValue(self.cfg.NormalQuality.max_internal_size)
elif self.gui.comboBox_quality.currentText() == 'High':
self.gui.long_mem_max.setValue(4000)
self.gui.quality_box.setValue(540)
self.gui.long_mem_max.setValue(self.cfg.HighQuality.max_num_tokens)
self.gui.quality_box.setValue(self.cfg.HighQuality.max_internal_size)
elif self.gui.comboBox_quality.currentText() == 'Ultra':
self.gui.long_mem_max.setValue(4000)
self.gui.quality_box.setValue(720)
self.gui.long_mem_max.setValue(self.cfg.UltraQuality.max_num_tokens)
self.gui.quality_box.setValue(self.cfg.UltraQuality.max_internal_size)
self.update_config()


Expand Down

0 comments on commit 985af49

Please sign in to comment.