Inversion #111

Draft · wants to merge 1 commit into main
2 changes: 1 addition & 1 deletion src/mflux/controlnet/flux_controlnet.py
@@ -141,7 +141,7 @@ def generate_image(
)

# 3.t Predict the noise
- noise = self.transformer.predict(
+ noise = self.transformer.predict_with_t(
t=t,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
2 changes: 1 addition & 1 deletion src/mflux/dreambooth/optimization/dreambooth_loss.py
@@ -53,7 +53,7 @@ def _single_example_loss(flux: Flux1, config: RuntimeConfig, example: Example, r
) # fmt: off

# Predict the noise from timestep t
- predicted_noise = flux.transformer.predict(
+ predicted_noise = flux.transformer.predict_with_t(
t=t,
prompt_embeds=example.prompt_embeds,
pooled_prompt_embeds=example.pooled_prompt_embeds,
101 changes: 91 additions & 10 deletions src/mflux/flux/flux.py
@@ -36,6 +36,7 @@ def __init__(
lora_scales: list[float] | None = None,
):
super().__init__()

self.lora_paths = lora_paths
self.lora_scales = lora_scales
self.model_config = model_config
@@ -66,13 +67,14 @@ def __init__(
lora_weights = WeightHandlerLoRA.load_lora_weights(transformer=self.transformer, lora_files=lora_paths, lora_scales=lora_scales) # fmt:off
WeightHandlerLoRA.set_lora_weights(transformer=self.transformer, loras=lora_weights)

- def generate_image(
+ def invert(
self,
- seed: int,
prompt: str,
+ seed: int,
+ init_image_path: Path,
config: Config = Config(),
stepwise_output_dir: Path = None,
- ) -> GeneratedImage:
+ ) -> (mx.array, mx.array):
# Create a new runtime config based on the model type and input parameters
config = RuntimeConfig(config, self.model_config)
time_steps = tqdm(range(config.init_time_step, config.num_inference_steps))
@@ -85,29 +87,108 @@ def generate_image(
output_dir=stepwise_output_dir,
)

- # 1. Create the initial latents
- latents = LatentCreator.create_for_txt2img_or_img2img(seed, config, self.vae)
+ # 1. Create the initial latents from the image
+ image_latents = LatentCreator.encode_image(
+ init_image_path=init_image_path,
+ width=config.width,
+ height=config.height,
+ vae=self.vae,
+ ) # fmt: off

# 2. Embed the prompt
t5_tokens = self.t5_tokenizer.tokenize(prompt)
clip_tokens = self.clip_tokenizer.tokenize(prompt)
prompt_embeds = self.t5_text_encoder(t5_tokens)
pooled_prompt_embeds = self.clip_text_encoder(clip_tokens)

latents = mx.array(image_latents)
for gen_step, t in enumerate(time_steps, 1):
try:
- # 3.t Predict the noise
- noise = self.transformer.predict(
- t=t,
+ dt = config.sigmas[config.num_inference_steps - 1 - t] - config.sigmas[config.num_inference_steps - t]
+
+ # 3.t Predict the noise with higher order terms
+ noise1 = self.transformer.predict_with_sigma(
+ t=float(config.num_inference_steps - t),
+ sigma_t=config.sigmas[config.num_inference_steps - t],
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
hidden_states=latents,
config=config,
)
noise2 = self.transformer.predict_with_sigma(
t=float(config.num_inference_steps - t) - 0.5,
sigma_t=config.sigmas[config.num_inference_steps - t] + 0.5 * dt,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
hidden_states=latents + noise1 * 0.5 * dt,
config=config,
)

# 4.t Take one denoise step
latents += dt * noise2

# Handle stepwise output if enabled
stepwise_handler.process_step(config.num_inference_steps - t, latents)

# Evaluate to enable progress tracking
mx.eval(latents)

except KeyboardInterrupt: # noqa: PERF203
stepwise_handler.handle_interruption()
raise StopImageGenerationException(f"Stopping image generation at step {t + 1}/{len(time_steps)}")

return latents, image_latents
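Note: the loop above advances the latents with an explicit midpoint step (second-order Runge–Kutta) rather than plain Euler: the transformer is evaluated once at the current sigma (noise1) and once half a step ahead (noise2), and only the midpoint slope is used for the update. A minimal standalone sketch of that update rule, with a hypothetical f standing in for the transformer's velocity prediction:

def midpoint_step(x, sigma, dt, f):
    # Slope at the current point (a plain Euler step would stop here).
    k1 = f(x, sigma)
    # Slope re-evaluated halfway along the Euler step.
    k2 = f(x + 0.5 * dt * k1, sigma + 0.5 * dt)
    # Advance the full step using only the midpoint slope.
    return x + dt * k2

The same two-evaluation pattern appears in generate_image below, with the sigma schedule traversed in the opposite direction.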

def generate_image(
self,
seed: int,
prompt: str,
latents: mx.array,
config: Config = Config(),
stepwise_output_dir: Path = None,
) -> GeneratedImage:
# Create a new runtime config based on the model type and input parameters
config = RuntimeConfig(config, self.model_config)
time_steps = tqdm(range(config.init_time_step, config.num_inference_steps))
stepwise_handler = StepwiseHandler(
flux=self,
config=config,
seed=seed,
prompt=prompt,
time_steps=time_steps,
output_dir=stepwise_output_dir,
)

# 1. Embed the prompt
t5_tokens = self.t5_tokenizer.tokenize(prompt)
clip_tokens = self.clip_tokenizer.tokenize(prompt)
prompt_embeds = self.t5_text_encoder(t5_tokens)
pooled_prompt_embeds = self.clip_text_encoder(clip_tokens)

for gen_step, t in enumerate(time_steps, 1):
try:
dt = config.sigmas[t + 1] - config.sigmas[t]
- latents += noise * dt

# 2.t Predict the noise with higher order terms
noise1 = self.transformer.predict_with_sigma(
t=float(t),
sigma_t=config.sigmas[t],
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
hidden_states=latents,
config=config,
)
noise2 = self.transformer.predict_with_sigma(
t=float(t) + 0.5,
sigma_t=config.sigmas[t] + 0.5 * dt,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
hidden_states=latents + noise1 * 0.5 * dt,
config=config,
)

# 3.t Take one denoise step
latents += dt * noise2

# Handle stepwise output if enabled
stepwise_handler.process_step(gen_step, latents)
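Note how the two loops index the shared sigma schedule in opposite directions: invert starts at sigma = 0 (the clean image) and walks the noise level up, while generate_image starts at the top of the schedule and integrates back down to zero. An illustrative trace, assuming a hypothetical 4-step schedule ending in a trailing zero:

N = 4
sigmas = [1.0, 0.75, 0.5, 0.25, 0.0]  # hypothetical schedule; sigmas[N] == 0
inversion = [sigmas[N - t] for t in range(N)]   # [0.0, 0.25, 0.5, 0.75] -- noising
generation = [sigmas[t] for t in range(N)]      # [1.0, 0.75, 0.5, 0.25] -- denoising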
@@ -119,7 +200,7 @@ def generate_image(
stepwise_handler.handle_interruption()
raise StopImageGenerationException(f"Stopping image generation at step {t + 1}/{len(time_steps)}")

- # 5. Decode the latent array and return the image
+ # 4. Decode the latent array and return the image
latents = ArrayUtil.unpack_latents(latents=latents, height=config.height, width=config.width)
decoded = self.vae.decode(latents)
return ImageUtil.to_image(
18 changes: 18 additions & 0 deletions src/mflux/flux/v_cache.py
@@ -0,0 +1,18 @@
import mlx.core as mx
import numpy as np


class VCache:
is_inverting = True  # True while invert() records value projections; False while generate_image() replays them
v_cache = {}  # maps hash((t, id(attention_module))) -> cached value projection
t_max = 5  # only timesteps t <= t_max are recorded and replayed

@staticmethod
def save_dict(data_dict, filename):
np_dict = {k: v.tolist() for k, v in data_dict.items()}
np.savez_compressed(filename, **np_dict)

@staticmethod
def load_dict(filename):
data = np.load(filename)
return {k: mx.array(v) for k, v in data.items()}
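A quick round-trip of the two helpers above (a minimal sketch; the key and filename are illustrative — note that np.savez_compressed only accepts string keyword names, so the integer hash keys used for v_cache elsewhere in this PR would need to be stringified before saving):

import mlx.core as mx

from mflux.flux.v_cache import VCache

cache = {"layer16_t3": mx.ones((1, 24, 64, 128))}  # hypothetical cached value tensor
VCache.save_dict(cache, "v_cache.npz")
restored = VCache.load_dict("v_cache.npz")
assert restored["layer16_t3"].shape == cache["layer16_t3"].shape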
67 changes: 40 additions & 27 deletions src/mflux/generate.py
@@ -1,47 +1,60 @@
import time
from pathlib import Path

from mflux import Config, Flux1, ModelConfig, StopImageGenerationException
- from mflux.ui.cli.parsers import CommandLineParser
+ from mflux.flux.v_cache import VCache

image_path = "/Users/filipstrand/Desktop/cat.png"
source_prompt = "A cat"
target_prompt = "A sleeping cat"
height = 256
width = 256
steps = 20
seed = 2
source_guidance = 1.5
target_guidance = 5.5
VCache.t_max = 10

- def main():
- # fmt: off
- parser = CommandLineParser(description="Generate an image based on a prompt.")
- parser.add_model_arguments(require_model_arg=False)
- parser.add_lora_arguments()
- parser.add_image_generator_arguments(supports_metadata_config=True)
- parser.add_image_to_image_arguments(required=False)
- parser.add_output_arguments()
- args = parser.parse_args()

+ def main():
# Load the model
flux = Flux1(
- model_config=ModelConfig.from_alias(args.model),
- quantize=args.quantize,
- local_path=args.path,
- lora_paths=args.lora_paths,
- lora_scales=args.lora_scales,
+ model_config=ModelConfig.FLUX1_DEV,
+ quantize=4,
)

try:
- # Generate an image
+ # Invert an existing image
+ VCache.is_inverting = True
inverted_latents, encoded_image = flux.invert(
seed=seed,
prompt=source_prompt,
init_image_path=Path(image_path),
stepwise_output_dir=Path("/Users/filipstrand/Desktop/backward"),
config=Config(
num_inference_steps=steps,
height=height,
width=width,
guidance=source_guidance,
),
)

# Generate a new image based on the inverted one
VCache.is_inverting = False
image = flux.generate_image(
- seed=int(time.time()) if args.seed is None else args.seed,
- prompt=args.prompt,
- stepwise_output_dir=Path(args.stepwise_image_output_dir) if args.stepwise_image_output_dir else None,
+ seed=seed,
+ prompt=target_prompt,
+ latents=inverted_latents,
+ stepwise_output_dir=Path("/Users/filipstrand/Desktop/forward"),
config=Config(
- num_inference_steps=args.steps,
- height=args.height,
- width=args.width,
- guidance=args.guidance,
- init_image_path=args.init_image_path,
- init_image_strength=args.init_image_strength,
+ num_inference_steps=steps,
+ height=height,
+ width=width,
+ guidance=target_guidance,
),
)

# Save the image
- image.save(path=args.output, export_json_metadata=args.metadata)
+ image.save(path="edited.png")
except StopImageGenerationException as stop_exc:
print(stop_exc)

31 changes: 24 additions & 7 deletions src/mflux/latent_creator/latent_creator.py
@@ -1,3 +1,5 @@
from pathlib import Path

import mlx.core as mx
from mlx import nn

@@ -35,21 +37,36 @@ def create_for_txt2img_or_img2img(
return pure_noise
else:
# Image2Image
- user_image = ImageUtil.load_image(runtime_conf.config.init_image_path).convert("RGB")
- scaled_user_image = ImageUtil.scale_to_dimensions(
- image=user_image,
- target_width=runtime_conf.width,
- target_height=runtime_conf.height,
+ latents = LatentCreator.encode_image(
+ init_image_path=runtime_conf.config.init_image_path,
+ height=runtime_conf.height,
+ width=runtime_conf.width,
+ vae=vae,
)
- encoded = vae.encode(ImageUtil.to_array(scaled_user_image))
- latents = ArrayUtil.pack_latents(latents=encoded, height=runtime_conf.height, width=runtime_conf.width)
sigma = runtime_conf.sigmas[runtime_conf.init_time_step]
return LatentCreator.add_noise_by_interpolation(
clean=latents,
noise=pure_noise,
sigma=sigma
) # fmt: off

@staticmethod
def encode_image(
init_image_path: Path,
width: int,
height: int,
vae: nn.Module,
):
user_image = ImageUtil.load_image(init_image_path).convert("RGB")
scaled_user_image = ImageUtil.scale_to_dimensions(
image=user_image,
target_width=width,
target_height=height,
)
encoded = vae.encode(ImageUtil.to_array(scaled_user_image))
latents = ArrayUtil.pack_latents(latents=encoded, height=height, width=width)
return latents

@staticmethod
def add_noise_by_interpolation(clean: mx.array, noise: mx.array, sigma: float) -> mx.array:
return (1 - sigma) * clean + sigma * noise
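add_noise_by_interpolation is the rectified-flow noising rule x_sigma = (1 - sigma) * x_clean + sigma * x_noise, so sigma = 0 reproduces the clean latents and sigma = 1 returns pure noise. A small sanity check (shapes are illustrative):

import mlx.core as mx

from mflux.latent_creator.latent_creator import LatentCreator

clean = mx.zeros((1, 1024, 64))  # hypothetical packed-latent shape
noise = mx.ones((1, 1024, 64))
assert mx.allclose(LatentCreator.add_noise_by_interpolation(clean, noise, 0.0), clean)
assert mx.allclose(LatentCreator.add_noise_by_interpolation(clean, noise, 1.0), noise)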
19 changes: 17 additions & 2 deletions src/mflux/models/transformer/single_block_attention.py
@@ -1,25 +1,40 @@
import mlx.core as mx
from mlx import nn

from mflux.flux.v_cache import VCache


class SingleBlockAttention(nn.Module):
head_dimension = 128
batch_size = 1
num_heads = 24

- def __init__(self):
+ def __init__(self, layer):
super().__init__()
+ self.layer = layer
self.to_q = nn.Linear(3072, 3072)
self.to_k = nn.Linear(3072, 3072)
self.to_v = nn.Linear(3072, 3072)
self.norm_q = nn.RMSNorm(128)
self.norm_k = nn.RMSNorm(128)

- def __call__(self, hidden_states: mx.array, image_rotary_emb: mx.array) -> (mx.array, mx.array):
+ def __call__(self, t: float, hidden_states: mx.array, image_rotary_emb: mx.array) -> (mx.array, mx.array):
query = self.to_q(hidden_states)
key = self.to_k(hidden_states)

# Handle the values from inversion
key_hash = hash((t, id(self)))
value = self.to_v(hidden_states)

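# For layers deeper than 15, swap the value projections: record them while
# inverting (for timesteps t <= VCache.t_max) and replay them while generating,
# falling back to the fresh projection on a cache miss.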
if self.layer > 15:
if VCache.is_inverting:
if t <= VCache.t_max:
VCache.v_cache[key_hash] = mx.array(value)
else:
if t <= VCache.t_max:
value = VCache.v_cache.get(key_hash, None)
value = value if value is not None else self.to_v(hidden_states)

query = mx.transpose(mx.reshape(query, (1, -1, 24, 128)), (0, 2, 1, 3))
key = mx.transpose(mx.reshape(key, (1, -1, 24, 128)), (0, 2, 1, 3))
value = mx.transpose(mx.reshape(value, (1, -1, 24, 128)), (0, 2, 1, 3))
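For reference, the reshape/transpose pair above splits each 3072-wide projection into 24 heads of dimension 128 and moves the head axis in front of the tokens; a minimal shape walk-through (sequence length is illustrative):

import mlx.core as mx

seq_len = 64  # hypothetical token count
x = mx.zeros((1, seq_len, 3072))  # a projected Q, K, or V tensor
x = mx.transpose(mx.reshape(x, (1, -1, 24, 128)), (0, 2, 1, 3))
assert tuple(x.shape) == (1, 24, seq_len, 128)  # (batch, heads, tokens, head_dim)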
4 changes: 3 additions & 1 deletion src/mflux/models/transformer/single_transformer_block.py
@@ -13,11 +13,12 @@ def __init__(self, layer):
self.layer = layer
self.norm = AdaLayerNormZeroSingle()
self.proj_mlp = nn.Linear(3072, 4 * 3072)
- self.attn = SingleBlockAttention()
+ self.attn = SingleBlockAttention(layer)
self.proj_out = nn.Linear(3072 + 4 * 3072, 3072)

def __call__(
self,
t: float,
hidden_states: mx.array,
text_embeddings: mx.array,
rotary_embeddings: mx.array,
@@ -26,6 +27,7 @@ def __call__(
norm_hidden_states, gate = self.norm(x=hidden_states, text_embeddings=text_embeddings)
mlp_hidden_states = nn.gelu_approx(self.proj_mlp(norm_hidden_states))
attn_output = self.attn(
t=t,
hidden_states=norm_hidden_states,
image_rotary_emb=rotary_embeddings,
)