Skip to content

Commit

Permalink
Fix handling of language-only (no-image) input
Browse files Browse the repository at this point in the history
  • Loading branch information
Blaizzy committed Dec 29, 2024
1 parent 16c579f commit 3d65478
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 30 deletions.
8 changes: 4 additions & 4 deletions mlx_vlm/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,11 @@ def main():
output = generate(
model,
processor,
args.image,
prompt,
args.temp,
args.max_tokens,
args.verbose,
image=args.image,
temp=args.temp,
max_tokens=args.max_tokens,
verbose=args.verbose,
**kwargs,
)
if not args.verbose:
Expand Down
2 changes: 1 addition & 1 deletion mlx_vlm/models/idefics2/idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def get_input_embeddings(
pixel_attention_mask: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
return self.language_model.embed_tokens(input_ids)

inputs_embeds = self.language_model.embed_tokens(input_ids)

Expand Down
2 changes: 1 addition & 1 deletion mlx_vlm/models/idefics3/idefics3.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def get_input_embeddings(
pixel_attention_mask: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
return self.language_model.embed_tokens(input_ids)

inputs_embeds = self.language_model.embed_tokens(input_ids)

Expand Down
2 changes: 1 addition & 1 deletion mlx_vlm/models/llava/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_input_embeddings(
pixel_values: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
return self.language_model.model.embed_tokens(input_ids)

# Get the input embeddings from the language model
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
Expand Down
2 changes: 1 addition & 1 deletion mlx_vlm/models/llava_bunny/llava_bunny.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def get_input_embeddings(
pixel_values: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
return self.language_model.model.embed_tokens(input_ids)

inputs_embeds = self.language_model.model.embed_tokens(input_ids)

Expand Down
2 changes: 1 addition & 1 deletion mlx_vlm/models/llava_next/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def get_input_embeddings(
pixel_values: Optional[mx.array] = None,
):
if pixel_values is None:
return self.language_model(input_ids)
return self.language_model.model.embed_tokens(input_ids)

# Get the input embeddings from the language model
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
Expand Down
11 changes: 6 additions & 5 deletions mlx_vlm/models/qwen2_vl/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,10 @@ def get_input_embeddings(
):

if pixel_values is None:
return self.language_model(input_ids)
return self.language_model.model.embed_tokens(input_ids)

dtype = self.vision_tower.patch_embed.proj.weight.dtype
pixel_values = pixel_values.astype(dtype)

# Get the input embeddings from the language model
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
Expand Down Expand Up @@ -97,10 +100,8 @@ def __call__(
**kwargs,
):
image_grid_thw = kwargs.pop("image_grid_thw", None)
image_grid_thw = mx.array(image_grid_thw)

dtype = self.vision_tower.patch_embed.proj.weight.dtype
pixel_values = pixel_values.astype(dtype)
if image_grid_thw is not None:
image_grid_thw = mx.array(image_grid_thw)

input_embddings = self.get_input_embeddings(
input_ids, pixel_values, image_grid_thw
Expand Down
41 changes: 25 additions & 16 deletions mlx_vlm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from io import BytesIO
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import mlx.core as mx
import mlx.nn as nn
Expand Down Expand Up @@ -963,9 +963,8 @@ def _step(y, **kwargs):
def stream_generate(
model: nn.Module,
processor: PreTrainedTokenizer,
image: str,
prompt: str,
image_processor=None,
image: Union[str, List[str]] = None,
max_tokens: int = 100,
**kwargs,
) -> Union[str, Generator[str, None, None]]:
Expand All @@ -986,8 +985,10 @@ def stream_generate(
if hasattr(processor, "image_processor") and isinstance(
processor.image_processor, BaseImageProcessor
):
prompt_tokens = mx.array(processor.encode(prompt))
tokenizer = processor
else:
prompt_tokens = mx.array(processor.tokenizer.encode(prompt))
tokenizer = processor.tokenizer

resize_shape = kwargs.pop("resize_shape", None)
Expand All @@ -997,17 +998,25 @@ def stream_generate(
image_token_index = None

# Prepare inputs
inputs = prepare_inputs(processor, image, prompt, image_token_index, resize_shape)
input_ids, pixel_values, mask = (
inputs["input_ids"],
inputs["pixel_values"],
inputs["mask"],
)
kwargs = {
k: v
for k, v in inputs.items()
if k not in ["input_ids", "pixel_values", "mask"]
}
if not image:
input_ids = prompt_tokens[None, :]
pixel_values = None
mask = None
kwargs = {}
else:
inputs = prepare_inputs(
processor, image, prompt, image_token_index, resize_shape
)
input_ids, pixel_values, mask = (
inputs["input_ids"],
inputs["pixel_values"],
inputs["attention_mask"],
)
kwargs = {
k: v
for k, v in inputs.items()
if k not in ["input_ids", "pixel_values", "mask"]
}

detokenizer = processor.detokenizer

Expand All @@ -1030,8 +1039,8 @@ def stream_generate(
def generate(
model: nn.Module,
processor: PreTrainedTokenizer,
image: str,
prompt: str,
image: Union[str, List[str]] = None,
temp: float = 0.0,
max_tokens: int = 100,
verbose: bool = False,
Expand Down Expand Up @@ -1078,7 +1087,7 @@ def generate(
else:
image_token_index = None

if image == []:
if not image:
input_ids = prompt_tokens[None, :]
pixel_values = None
mask = None
Expand Down

0 comments on commit 3d65478

Please sign in to comment.