Commit 330813c

add log for training and validation
shuishen112 committed Jul 15, 2024
1 parent 665bfc4 commit 330813c
Showing 1 changed file with 53 additions and 4 deletions.
mttl/models/expert_model.py (57 changes: 53 additions & 4 deletions)
@@ -626,7 +626,7 @@ def calculate_DPO_loss(
         beta * (original_prefered_relative_logprob - disprefered_relative_logprob)
     ).mean(dim=-1)
 
-    return loss
+    return loss, reward_accuracies, reward_margins
 
 
 def get_log_prob(logits, labels):
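
Note: the hunk above only shows the tail of calculate_DPO_loss and the signature of get_log_prob. As context, here is a minimal sketch of how these two helpers are typically written for DPO training; apart from the beta * (original_prefered_relative_logprob - disprefered_relative_logprob) line and the three return values visible in the diff, the bodies below are assumptions about the standard formulation, not the repository's exact code.

import torch
import torch.nn.functional as F


def get_log_prob(logits, labels):
    # Assumed helper: average log-probability of the observed tokens under the model.
    log_probs = F.log_softmax(logits, dim=-1)
    return torch.gather(log_probs, -1, labels.unsqueeze(-1)).squeeze(-1).mean(-1)


def calculate_DPO_loss(
    model_prefered_log_prob,
    model_disprefered_log_prob,
    ref_prefered_log_prob,
    ref_disprefered_log_prob,
    beta=0.1,
):
    # Implicit rewards: how far the policy has moved from the reference model on the
    # preferred vs. the dispreferred completion.
    original_prefered_relative_logprob = model_prefered_log_prob - ref_prefered_log_prob
    disprefered_relative_logprob = model_disprefered_log_prob - ref_disprefered_log_prob

    # Diagnostics returned alongside the loss (exact computation assumed here):
    # fraction of pairs where the preferred completion wins, and the mean reward gap.
    reward_accuracies = (
        (original_prefered_relative_logprob > disprefered_relative_logprob).float().mean(dim=-1)
    )
    reward_margins = (
        original_prefered_relative_logprob - disprefered_relative_logprob
    ).mean(dim=-1)

    # Standard DPO objective; the expression inside logsigmoid is the line visible in the diff.
    loss = -F.logsigmoid(
        beta * (original_prefered_relative_logprob - disprefered_relative_logprob)
    ).mean(dim=-1)

    return loss, reward_accuracies, reward_margins
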
@@ -683,13 +683,30 @@ def training_step(self, batch, _):
             labels=prompt_disprefered_ids,
         )
 
-        loss = calculate_DPO_loss(
+        loss, reward_accuracies, reward_margins = calculate_DPO_loss(
             model_prefered_log_prob,
             model_disprefered_log_prob,
             ref_prefered_log_prob,
             ref_disprefered_log_prob,
             beta=0.1,
         )
+        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
+
+        self.log(
+            "train/reward_accuracies",
+            reward_accuracies,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+        )
+        self.log(
+            "train/reward_margins",
+            reward_margins,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+        )
+
         return loss
 
     def validation_step(self, batch, _):
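
Note on the logging API: self.log expects scalar values, so reward_accuracies and reward_margins must already be reduced to scalars (e.g. batch means) inside calculate_DPO_loss, and with on_step=True, on_epoch=True Lightning records both a per-batch value and an epoch aggregate for each metric. As an aside, the same training-side metrics could be grouped with LightningModule.log_dict; this is only an illustration of the API, not what the commit does:

# Equivalent grouped logging for the training split (illustration only).
self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
self.log_dict(
    {
        "train/reward_accuracies": reward_accuracies,
        "train/reward_margins": reward_margins,
    },
    on_step=True,
    on_epoch=True,
    prog_bar=True,
)
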
@@ -729,13 +746,30 @@ def validation_step(self, batch, _):
             labels=prompt_disprefered_ids,
         )
 
-        loss = calculate_DPO_loss(
+        loss, reward_accuracies, reward_margins = calculate_DPO_loss(
             model_prefered_log_prob,
             model_disprefered_log_prob,
             ref_prefered_log_prob,
             ref_disprefered_log_prob,
             beta=0.1,
         )
+
+        self.log("val/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
+        self.log(
+            "val/reward_accuracies",
+            reward_accuracies,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+        )
+        self.log(
+            "val/reward_margins",
+            reward_margins,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+        )
+
         return loss
 
     def test_step(self, batch, _):
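
Since training_step and validation_step above (and test_step below) now differ only in the train/, val/, and test/ metric prefixes, a natural follow-up not included in this commit would be a shared helper. A sketch, where _dpo_step and _compute_log_probs are hypothetical names and the elided forward passes are the ones truncated in the hunks above:

    def _dpo_step(self, batch, prefix):
        # Hypothetical refactor: the forward passes producing the four log-prob
        # tensors are identical across the three hooks and are elided here.
        (
            model_prefered_log_prob,
            model_disprefered_log_prob,
            ref_prefered_log_prob,
            ref_disprefered_log_prob,
        ) = self._compute_log_probs(batch)  # hypothetical helper

        loss, reward_accuracies, reward_margins = calculate_DPO_loss(
            model_prefered_log_prob,
            model_disprefered_log_prob,
            ref_prefered_log_prob,
            ref_disprefered_log_prob,
            beta=0.1,
        )
        for name, value in (
            ("loss", loss),
            ("reward_accuracies", reward_accuracies),
            ("reward_margins", reward_margins),
        ):
            self.log(f"{prefix}/{name}", value, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch, _):
        return self._dpo_step(batch, "train")

    def validation_step(self, batch, _):
        return self._dpo_step(batch, "val")

    def test_step(self, batch, _):
        return self._dpo_step(batch, "test")
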
@@ -775,13 +809,28 @@ def test_step(self, batch, _):
             labels=prompt_disprefered_ids,
         )
 
-        loss = calculate_DPO_loss(
+        loss, reward_accuracies, reward_margins = calculate_DPO_loss(
             model_prefered_log_prob,
             model_disprefered_log_prob,
             ref_prefered_log_prob,
             ref_disprefered_log_prob,
             beta=0.1,
         )
+        self.log("test/loss", loss, on_step=True, on_epoch=True, prog_bar=True)
+        self.log(
+            "test/reward_accuracies",
+            reward_accuracies,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+        )
+        self.log(
+            "test/reward_margins",
+            reward_margins,
+            on_step=True,
+            on_epoch=True,
+            prog_bar=True,
+        )
         return loss
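
The self.log calls indicate the class is a pytorch_lightning LightningModule driven by a Trainer, which is what routes the logged values to the logger and the progress bar. A minimal, hypothetical wiring to show where the new metrics end up; module and the dataloaders are placeholders, not objects from the repository:

import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger

# Placeholders: `module` is the DPO expert model from mttl/models/expert_model.py,
# and the three dataloaders yield the preferred/dispreferred prompt batches.
trainer = pl.Trainer(max_epochs=1, logger=CSVLogger("logs"))
trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(module, dataloaders=test_loader)

# The logger then contains train/loss, train/reward_accuracies and train/reward_margins
# (recorded per step and aggregated per epoch), plus the corresponding val/* and test/*
# metrics, and the same values appear in the progress bar because prog_bar=True.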

