Callbacks

class ruprompts.callbacks.FreezeTransformerUnfreezePrompt

Freezes all model parameters except those of the prompt provider.

on_train_begin(args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs)

Event called at the beginning of training.

Source code in ruprompts/callbacks.py
def on_train_begin(
    self,
    args: TrainingArguments,
    state: TrainerState,
    control: TrainerControl,
    model: PreTrainedModel,
    **kwargs,
):
    # Enable gradients only for prompt provider parameters and
    # freeze everything else in the transformer.
    for name, param in model.transformer.named_parameters():
        if PROMPT_PROVIDER_KEY_NAME in name:
            param.requires_grad = True
        else:
            param.requires_grad = False
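
For orientation, a minimal sketch of attaching this callback to a Hugging Face Trainer; model and train_dataset are placeholders standing in for a prompt-augmented PreTrainedModel and its training data:

from transformers import Trainer, TrainingArguments
from ruprompts.callbacks import FreezeTransformerUnfreezePrompt

trainer = Trainer(
    model=model,  # placeholder: a prompt-augmented PreTrainedModel
    args=TrainingArguments(output_dir="checkpoints"),
    train_dataset=train_dataset,  # placeholder dataset
    callbacks=[FreezeTransformerUnfreezePrompt()],
)
trainer.train()  # only prompt provider parameters receive gradient updates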

class ruprompts.callbacks.ReduceCheckpoint

Reduces checkpoint size by keeping only the weights of the prompt provider.

on_save(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs)

Event called after a checkpoint save.

Source code in ruprompts/callbacks.py
def on_save(
    self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
):
    # Locate the checkpoint that was just written.
    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
    output_dir = os.path.join(args.output_dir, checkpoint_folder)
    weights_path = os.path.join(output_dir, WEIGHTS_NAME)
    weights = torch.load(weights_path)

    # Collect every tensor that does not belong to the prompt provider.
    keys_to_remove = []
    for weight_key in weights:
        if PROMPT_PROVIDER_KEY_NAME not in weight_key:
            keys_to_remove.append(weight_key)

    # Drop them and overwrite the checkpoint with the reduced state dict.
    for key in keys_to_remove:
        weights.pop(key)
    torch.save(weights, weights_path)
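
To illustrate the effect, a reduced checkpoint contains only prompt provider tensors; the checkpoint path below is illustrative, and WEIGHTS_NAME resolves to pytorch_model.bin in transformers:

import torch

# Load a checkpoint written after ReduceCheckpoint has run (illustrative path).
weights = torch.load("checkpoints/checkpoint-500/pytorch_model.bin")

# Every remaining key refers to the prompt provider.
print(list(weights.keys()))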

class ruprompts.callbacks.SavePretrainedPrompt(prompt: Prompt)

Saves the prompt in pretrained format on each checkpoint save.

Parameters:

Name    Type    Description                   Default
prompt  Prompt  Prompt instance to be saved.  required

on_save(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs)

Event called after a checkpoint save.

Source code in ruprompts/callbacks.py
def on_save(
    self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
):
    # Save the prompt in pretrained format into the checkpoint directory.
    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
    output_dir = os.path.join(args.output_dir, checkpoint_folder)
    self.prompt.save_pretrained(output_dir)
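
A usage sketch, assuming a Prompt instance has already been created elsewhere; reloading via Prompt.from_pretrained is shown as a comment, on the assumption that it mirrors the usual Hugging Face naming pattern:

from ruprompts.callbacks import SavePretrainedPrompt

trainer = Trainer(
    model=model,  # placeholders, as in the example above
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[SavePretrainedPrompt(prompt)],
)
trainer.train()

# The prompt can later be restored from any checkpoint directory, e.g.
# prompt = Prompt.from_pretrained("checkpoints/checkpoint-500")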

class ruprompts.callbacks.WBLogHydraConfig(cfg: omegaconf.DictConfig)

Logs the Hydra config to Weights & Biases on training start.

Parameters:

Name  Type                  Description           Default
cfg   omegaconf.DictConfig  Config to be logged.  required

on_train_begin(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs)

Event called at the beginning of training.

Source code in ruprompts/callbacks.py
def on_train_begin(
    self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
):
    # Convert the OmegaConf config to a plain container and attach it
    # to the active W&B run under the "hydra" key.
    wandb.config.update({"hydra": omegaconf.OmegaConf.to_container(self.cfg)})
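
For context, a sketch of a Hydra entry point using this callback; it assumes the Trainer is configured with report_to=["wandb"] so that a W&B run exists when on_train_begin fires, and the model and training arguments are placeholders that would be built from cfg in practice:

import hydra
import omegaconf
from transformers import Trainer
from ruprompts.callbacks import WBLogHydraConfig

@hydra.main(config_path="conf", config_name="config")
def main(cfg: omegaconf.DictConfig):
    trainer = Trainer(
        model=model,  # placeholder, built from cfg in practice
        args=training_args,  # placeholder, with report_to=["wandb"]
        callbacks=[WBLogHydraConfig(cfg)],
    )
    trainer.train()

if __name__ == "__main__":
    main()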