Callbacks
class ruprompts.callbacks.FreezeTransformerUnfreezePrompt
Freezes all transformer parameters except those of the prompt provider.
on_train_begin(args: TrainingArguments, state: TrainerState, control: TrainerControl, model: PreTrainedModel, **kwargs)
Event called at the beginning of training.
Source code in ruprompts/callbacks.py
def on_train_begin(
    self,
    args: TrainingArguments,
    state: TrainerState,
    control: TrainerControl,
    model: PreTrainedModel,
    **kwargs,
):
    for name, param in model.transformer.named_parameters():
        if PROMPT_PROVIDER_KEY_NAME in name:
            param.requires_grad = True
        else:
            param.requires_grad = False
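The name-based selection rule above can be tried without a real transformer. The following is a minimal sketch using plain-Python stand-ins: `FakeParam`, `freeze_all_but_prompt`, and the parameter names are hypothetical, and the value `"prompt_provider"` for `PROMPT_PROVIDER_KEY_NAME` is an assumption rather than something shown in the source excerpt.

```python
from dataclasses import dataclass

# Assumed value for illustration only; the real constant is defined in ruprompts.
PROMPT_PROVIDER_KEY_NAME = "prompt_provider"


@dataclass
class FakeParam:
    """Hypothetical stand-in for a torch parameter; only the flag matters here."""

    requires_grad: bool = True


def freeze_all_but_prompt(named_parameters):
    """Apply the same rule as on_train_begin: trainable only if the name matches."""
    for name, param in named_parameters:
        param.requires_grad = PROMPT_PROVIDER_KEY_NAME in name


# Hypothetical parameter names, loosely modelled on a GPT-style module tree.
params = {
    "wte.weight": FakeParam(),
    "h.0.attn.c_attn.weight": FakeParam(),
    "prompt_provider.weight": FakeParam(),
}
freeze_all_but_prompt(params.items())

# Collect the names that remain trainable after freezing.
trainable = [name for name, p in params.items() if p.requires_grad]
```

Because the match is a substring test on the parameter name, any parameter whose name happens to contain the key would stay trainable, which is worth keeping in mind when naming custom modules.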
class ruprompts.callbacks.ReduceCheckpoint
Reduces the checkpoint size by keeping only the weights of the prompt provider.
on_save(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs)
Event called after a checkpoint save.
Source code in ruprompts/callbacks.py
def on_save(
    self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
):
    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
    output_dir = os.path.join(args.output_dir, checkpoint_folder)
    weights_path = os.path.join(output_dir, WEIGHTS_NAME)
    weights = torch.load(weights_path)
    keys_to_remove = []
    for weight_key in weights:
        if PROMPT_PROVIDER_KEY_NAME not in weight_key:
            keys_to_remove.append(weight_key)
    for key in keys_to_remove:
        weights.pop(key)
    torch.save(weights, weights_path)
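The key-filtering step in the middle of `on_save` can be illustrated on an ordinary dictionary. This is a sketch only: `keep_prompt_weights` and the key names are hypothetical, lists stand in for tensors, and the value of `PROMPT_PROVIDER_KEY_NAME` is assumed. It builds a new dictionary instead of popping keys in place, which is a different but equivalent way to get the same result.

```python
# Assumed value for illustration only; the real constant is defined in ruprompts.
PROMPT_PROVIDER_KEY_NAME = "prompt_provider"


def keep_prompt_weights(state_dict):
    """Return a new dict with only the entries whose key mentions the prompt provider."""
    return {
        key: value
        for key, value in state_dict.items()
        if PROMPT_PROVIDER_KEY_NAME in key
    }


# Hypothetical checkpoint contents; plain lists stand in for weight tensors.
weights = {
    "transformer.wte.weight": [0.1, 0.2],
    "transformer.prompt_provider.weight": [0.3],
    "lm_head.weight": [0.4],
}
reduced = keep_prompt_weights(weights)
```

A checkpoint reduced this way holds only the prompt weights, so restoring a complete model from it would presumably also require the original transformer weights from elsewhere.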
class ruprompts.callbacks.SavePretrainedPrompt(prompt: Prompt)
Saves the prompt with save_pretrained into the checkpoint folder each time a checkpoint is saved.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
prompt | Prompt | Prompt instance to be saved. | required |
on_save(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs)
Event called after a checkpoint save.
Source code in ruprompts/callbacks.py
def on_save(
    self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
):
    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}"
    output_dir = os.path.join(args.output_dir, checkpoint_folder)
    self.prompt.save_pretrained(output_dir)
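Both `on_save` implementations derive the checkpoint folder from the output directory and the current global step. The sketch below reproduces only that path computation; `checkpoint_dir`, the `runs/demo` directory, and the step value are hypothetical, and `PREFIX_CHECKPOINT_DIR` is redefined locally with the value `"checkpoint"` that `transformers.trainer_utils` uses.

```python
import os

# Local copy for illustration; transformers.trainer_utils defines the same value.
PREFIX_CHECKPOINT_DIR = "checkpoint"


def checkpoint_dir(output_dir, global_step):
    """Build the per-step checkpoint folder path the same way on_save does."""
    checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{global_step}"
    return os.path.join(output_dir, checkpoint_folder)


# Hypothetical output directory and step number.
path = checkpoint_dir("runs/demo", 500)
```

Since the callback writes into the same folder the trainer just created, the saved prompt ends up next to the regular checkpoint files for that step.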
class ruprompts.callbacks.WBLogHydraConfig(cfg)
Logs Hydra config to Weights and Biases on training start.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
cfg | omegaconf.DictConfig | Config to be logged. | required |
on_train_begin(args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs)
Event called at the beginning of training.
Source code in ruprompts/callbacks.py
def on_train_begin(
    self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs
):
    wandb.config.update({"hydra": omegaconf.OmegaConf.to_container(self.cfg)})