Kkit.llm_utils.fine_tune_utils
```python
from pydantic import BaseModel
from typing import Dict, Optional
import torch
from typing import List, Union
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainerCallback
)
import threading
import os
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, PeftModel
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM


# Global training state tracker
class TrainingState:
    def __init__(self):
        self.lock = threading.Lock()
        self.current_task: Optional[Dict] = None

    def update_state(self, **kwargs):
        with self.lock:
            if self.current_task is None:
                self.current_task = {}
            self.current_task.update(kwargs)

    def get_state(self):
        with self.lock:
            return self.current_task.copy() if self.current_task else None

training_state = TrainingState()

# Training progress callback
class ProgressCallback(TrainerCallback):
    def on_train_begin(self, args, state, control, **kwargs):
        training_state.update_state(
            total_epochs=state.num_train_epochs,
            total_steps=state.max_steps
        )

    def on_epoch_end(self, args, state, control, **kwargs):
        training_state.update_state(
            current_epoch=int(state.epoch),
            current_step=state.global_step
        )

    def on_step_end(self, args, state, control, **kwargs):
        training_state.update_state(
            current_step=state.global_step,
            total_steps=state.max_steps
        )

    def on_log(self, args, state, control, logs=None, **kwargs):
        if torch.cuda.is_available():
            logs["gpu_alloc_mem"] = torch.cuda.memory_allocated() / 1024**2
            logs["gpu_reserved_mem"] = torch.cuda.memory_reserved() / 1024**2
        else:
            logs["gpu_alloc_mem"] = 0
            logs["gpu_reserved_mem"] = 0

TRAIN_ROUND = 0

# Training configuration model
class TrainConfig(BaseModel):
    model_name: str = "Qwen/Qwen2.5-0.5B"
    lora_path: Optional[str] = None  # Path to an existing LoRA adapter; train from scratch if None, continue from it if set
    lora_rank: int = 8
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    epochs: int = 3
    batch_size: int = 4
    learning_rate: float = 3e-4
    max_length: Optional[int] = None
    model_save_path: Optional[str] = None
    response_template: Optional[str] = "<|im_start|>assistant\n"  # Marks where the assistant response starts (completion-only loss)
    lora_target_modules: Union[List[str], str] = "all-linear"
    lora_modules_to_save: Optional[List[str]] = None  # e.g. ["lm_head", "embed_tokens"]
    tokenizer_padding_side: Optional[str] = "left"
    attn_implementation: str = "flash_attention_2"
    model_load_torch_dtype: str = "auto"
    train_arg_bf16: bool = True
    train_arg_fp16: bool = False
    train_round: int = TRAIN_ROUND

class MergeConfig(BaseModel):
    model_name: str
    lora_path: str
    model_output: str

# Core training function
def train_model(config: TrainConfig, dataset_path: str, base_path: str):
    training_state.update_state(
        status="training",
        message="Initializing model..."
    )

    # Handle model save path
    if config.model_save_path:
        if os.path.isabs(config.model_save_path):
            model_path = config.model_save_path
        else:
            model_path = os.path.join(base_path, config.model_save_path)
    else:
        model_path = os.path.join(base_path, "lora_finetuned")

    # Ensure directory exists with write permission
    os.makedirs(model_path, exist_ok=True)
    if not os.access(model_path, os.W_OK):
        raise RuntimeError(f"No write permission for {model_path}")

    # Load model
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if config.tokenizer_padding_side is not None:
        tokenizer.padding_side = config.tokenizer_padding_side
    base_model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        device_map="auto",
        torch_dtype=config.model_load_torch_dtype,
        attn_implementation=config.attn_implementation
    )

    if config.lora_path:
        # Load existing LoRA adapter
        model = PeftModel.from_pretrained(
            base_model,
            config.lora_path,
            device_map="auto"
        )
    else:
        # Configure new LoRA
        lora_config = LoraConfig(
            r=config.lora_rank,
            lora_alpha=config.lora_alpha,
            target_modules=config.lora_target_modules,
            lora_dropout=config.lora_dropout,
            bias="none",
            modules_to_save=config.lora_modules_to_save,
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(base_model, lora_config)

    # Load and preprocess dataset
    training_state.update_state(message="Loading dataset...")
    dataset = load_dataset("json", data_files=dataset_path, split="train")

    def formatting_prompts_func(example):
        output_texts_ids = tokenizer.apply_chat_template(example["messages"])
        output_texts = tokenizer.decode(output_texts_ids)
        return output_texts

    # Example dataset format:
    # {"messages": [{"role": "user", "content": "What color is the sky?"},
    #               {"role": "assistant", "content": "It is blue."}]}

    trainer_args = SFTConfig(
        output_dir=model_path,
        per_device_train_batch_size=config.batch_size,
        max_seq_length=config.max_length if config.max_length else tokenizer.model_max_length,
        learning_rate=config.learning_rate,
        num_train_epochs=config.epochs,
        logging_dir=os.path.join(base_path, "logs"),
        report_to="wandb",
        run_name=f"{config.model_name}-{config.train_round}",
        logging_steps=10,
        save_strategy="epoch",
        bf16=config.train_arg_bf16,
        fp16=config.train_arg_fp16
    )

    trainer = SFTTrainer(
        model=model,
        data_collator=DataCollatorForCompletionOnlyLM(config.response_template, tokenizer=tokenizer),
        train_dataset=dataset,
        callbacks=[ProgressCallback],
        args=trainer_args,
        formatting_func=formatting_prompts_func
    )

    training_state.update_state(message="Training...")
    trainer.train()

    # Final state update
    training_state.update_state(
        status="completed",
        message="training finished",
        model_path=model_path
    )

def merge_model(
    model_name: str,
    lora_path: str,
    model_output: str,
    save_tokenizer: bool = True,
    save_config: bool = True,
    safe_serialization: bool = True,
    max_shard_size: str = "4GB",
    torch_dtype: str = "auto",
    push_to_hub: bool = False,
):
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch_dtype,
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    lora_model = PeftModel.from_pretrained(base_model, lora_path)

    merged_model = lora_model.merge_and_unload()

    merged_model.save_pretrained(
        model_output,
        safe_serialization=safe_serialization,
        max_shard_size=max_shard_size
    )

    if save_tokenizer:
        tokenizer.save_pretrained(
            model_output,
            legacy_format=True,
            safe_serialization=True
        )

    if save_config:
        merged_model.config.save_pretrained(model_output)

    if push_to_hub:
        merged_model.push_to_hub(model_output)
        tokenizer.push_to_hub(model_output)

    return merged_model, tokenizer
```
```python
class TrainingState:
    def __init__(self):
        self.lock = threading.Lock()
        self.current_task: Optional[Dict] = None

    def update_state(self, **kwargs):
        with self.lock:
            if self.current_task is None:
                self.current_task = {}
            self.current_task.update(kwargs)

    def get_state(self):
        with self.lock:
            return self.current_task.copy() if self.current_task else None
```
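For context, a minimal usage sketch (not part of the module): `training_state` is a module-level instance of this class, so a monitoring thread, for example an API handler, can poll it while `train_model` runs elsewhere. The polling loop below is illustrative.

```python
# Sketch only: poll the shared training_state while training runs in another thread.
import time
import threading
from Kkit.llm_utils.fine_tune_utils import training_state

def monitor(poll_seconds: float = 5.0):
    while True:
        state = training_state.get_state()  # thread-safe copy, or None before any task starts
        if state is None:
            print("no training task yet")
        else:
            print(state.get("status"), state.get("message"), state.get("current_step"))
            if state.get("status") == "completed":
                break
        time.sleep(poll_seconds)

threading.Thread(target=monitor, daemon=True).start()
```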
```python
class ProgressCallback(TrainerCallback):
    def on_train_begin(self, args, state, control, **kwargs):
        training_state.update_state(
            total_epochs=state.num_train_epochs,
            total_steps=state.max_steps
        )

    def on_epoch_end(self, args, state, control, **kwargs):
        training_state.update_state(
            current_epoch=int(state.epoch),
            current_step=state.global_step
        )

    def on_step_end(self, args, state, control, **kwargs):
        training_state.update_state(
            current_step=state.global_step,
            total_steps=state.max_steps
        )

    def on_log(self, args, state, control, logs=None, **kwargs):
        if torch.cuda.is_available():
            logs["gpu_alloc_mem"] = torch.cuda.memory_allocated() / 1024**2
            logs["gpu_reserved_mem"] = torch.cuda.memory_reserved() / 1024**2
        else:
            logs["gpu_alloc_mem"] = 0
            logs["gpu_reserved_mem"] = 0
```
A class for objects that will inspect the state of the training loop at some events and take some decisions. At each of those events the following arguments are available:

Args:
    args ([`TrainingArguments`]):
        The training arguments used to instantiate the [`Trainer`].
    state ([`TrainerState`]):
        The current state of the [`Trainer`].
    control ([`TrainerControl`]):
        The object that is returned to the [`Trainer`] and can be used to make some decisions.
    model ([`PreTrainedModel`] or `torch.nn.Module`):
        The model being trained.
    tokenizer ([`PreTrainedTokenizer`]):
        The tokenizer used for encoding the data. This is deprecated in favour of `processing_class`.
    processing_class ([`PreTrainedTokenizer`] or [`BaseImageProcessor`] or [`ProcessorMixin`] or [`FeatureExtractionMixin`]):
        The processing class used for encoding the data. Can be a tokenizer, a processor, an image processor or a feature extractor.
    optimizer (`torch.optim.Optimizer`):
        The optimizer used for the training steps.
    lr_scheduler (`torch.optim.lr_scheduler.LambdaLR`):
        The scheduler used for setting the learning rate.
    train_dataloader (`torch.utils.data.DataLoader`, *optional*):
        The current dataloader used for training.
    eval_dataloader (`torch.utils.data.DataLoader`, *optional*):
        The current dataloader used for evaluation.
    metrics (`Dict[str, float]`):
        The metrics computed by the last evaluation phase.
        Those are only accessible in the event `on_evaluate`.
    logs (`Dict[str, float]`):
        The values to log.
        Those are only accessible in the event `on_log`.

The `control` object is the only one that can be changed by the callback, in which case the event that changes it should return the modified version.

The arguments `args`, `state` and `control` are positionals for all events; all the others are grouped in `kwargs`. You can unpack the ones you need in the signature of the event. As an example, see the code of the simple [`~transformers.PrinterCallback`].

Example:

```python
class PrinterCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        _ = logs.pop("total_flos", None)
        if state.is_local_process_zero:
            print(logs)
```
```python
def on_train_begin(self, args, state, control, **kwargs):
    training_state.update_state(
        total_epochs=state.num_train_epochs,
        total_steps=state.max_steps
    )
```
Event called at the beginning of training.
```python
def on_epoch_end(self, args, state, control, **kwargs):
    training_state.update_state(
        current_epoch=int(state.epoch),
        current_step=state.global_step
    )
```
Event called at the end of an epoch.
```python
def on_step_end(self, args, state, control, **kwargs):
    training_state.update_state(
        current_step=state.global_step,
        total_steps=state.max_steps
    )
```
Event called at the end of a training step. If using gradient accumulation, one training step might take several inputs.
```python
def on_log(self, args, state, control, logs=None, **kwargs):
    if torch.cuda.is_available():
        logs["gpu_alloc_mem"] = torch.cuda.memory_allocated() / 1024**2
        logs["gpu_reserved_mem"] = torch.cuda.memory_reserved() / 1024**2
    else:
        logs["gpu_alloc_mem"] = 0
        logs["gpu_reserved_mem"] = 0
```
Event called after logging the last logs.
```python
class TrainConfig(BaseModel):
    model_name: str = "Qwen/Qwen2.5-0.5B"
    lora_path: Optional[str] = None  # Path to an existing LoRA adapter; train from scratch if None, continue from it if set
    lora_rank: int = 8
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    epochs: int = 3
    batch_size: int = 4
    learning_rate: float = 3e-4
    max_length: Optional[int] = None
    model_save_path: Optional[str] = None
    response_template: Optional[str] = "<|im_start|>assistant\n"  # Marks where the assistant response starts (completion-only loss)
    lora_target_modules: Union[List[str], str] = "all-linear"
    lora_modules_to_save: Optional[List[str]] = None  # e.g. ["lm_head", "embed_tokens"]
    tokenizer_padding_side: Optional[str] = "left"
    attn_implementation: str = "flash_attention_2"
    model_load_torch_dtype: str = "auto"
    train_arg_bf16: bool = True
    train_arg_fp16: bool = False
    train_round: int = TRAIN_ROUND
```
Usage docs: https://docs.pydantic.dev/2.10/concepts/models/
A base class for creating Pydantic models.
Attributes:
__class_vars__: The names of the class variables defined on the model.
__private_attributes__: Metadata about the private attributes of the model.
__signature__: The synthesized `__init__` [`Signature`][inspect.Signature] of the model.
__pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__: The core schema of the model.
__pydantic_custom_init__: Whether the model has a custom `__init__` function.
__pydantic_decorators__: Metadata containing the decorators defined on the model.
This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
__pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to
__args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__: The name of the post-init method for the model, if defined.
__pydantic_root_model__: Whether the model is a [`RootModel`][pydantic.root_model.RootModel].
__pydantic_serializer__: The `pydantic-core` `SchemaSerializer` used to dump instances of the model.
__pydantic_validator__: The `pydantic-core` `SchemaValidator` used to validate instances of the model.
__pydantic_fields__: A dictionary of field names and their corresponding [`FieldInfo`][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__: A dictionary of computed field names and their corresponding [`ComputedFieldInfo`][pydantic.fields.ComputedFieldInfo] objects.
__pydantic_extra__: A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra]
is set to `'allow'`.
__pydantic_fields_set__: The names of fields explicitly set during instantiation.
__pydantic_private__: Values of private attributes set on the model instance.
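Since `TrainConfig` is a Pydantic model, a configuration can be built from keyword arguments or validated from a plain dict such as a parsed JSON request body. A short sketch with illustrative values only:

```python
from Kkit.llm_utils.fine_tune_utils import TrainConfig

# Defaults only.
config = TrainConfig()

# Validate from a dict (values here are illustrative).
config = TrainConfig.model_validate({
    "model_name": "Qwen/Qwen2.5-0.5B",
    "lora_rank": 16,
    "epochs": 1,
    "batch_size": 2,
    "model_save_path": "runs/lora_demo",  # hypothetical relative path, resolved against base_path
})

print(config.model_dump())  # plain dict of the validated fields (Pydantic v2)
```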
```python
def train_model(config: TrainConfig, dataset_path: str, base_path: str):
    training_state.update_state(
        status="training",
        message="Initializing model..."
    )

    # Handle model save path
    if config.model_save_path:
        if os.path.isabs(config.model_save_path):
            model_path = config.model_save_path
        else:
            model_path = os.path.join(base_path, config.model_save_path)
    else:
        model_path = os.path.join(base_path, "lora_finetuned")

    # Ensure directory exists with write permission
    os.makedirs(model_path, exist_ok=True)
    if not os.access(model_path, os.W_OK):
        raise RuntimeError(f"No write permission for {model_path}")

    # Load model
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if config.tokenizer_padding_side is not None:
        tokenizer.padding_side = config.tokenizer_padding_side
    base_model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        device_map="auto",
        torch_dtype=config.model_load_torch_dtype,
        attn_implementation=config.attn_implementation
    )

    if config.lora_path:
        # Load existing LoRA adapter
        model = PeftModel.from_pretrained(
            base_model,
            config.lora_path,
            device_map="auto"
        )
    else:
        # Configure new LoRA
        lora_config = LoraConfig(
            r=config.lora_rank,
            lora_alpha=config.lora_alpha,
            target_modules=config.lora_target_modules,
            lora_dropout=config.lora_dropout,
            bias="none",
            modules_to_save=config.lora_modules_to_save,
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(base_model, lora_config)

    # Load and preprocess dataset
    training_state.update_state(message="Loading dataset...")
    dataset = load_dataset("json", data_files=dataset_path, split="train")

    def formatting_prompts_func(example):
        output_texts_ids = tokenizer.apply_chat_template(example["messages"])
        output_texts = tokenizer.decode(output_texts_ids)
        return output_texts

    # Example dataset format:
    # {"messages": [{"role": "user", "content": "What color is the sky?"},
    #               {"role": "assistant", "content": "It is blue."}]}

    trainer_args = SFTConfig(
        output_dir=model_path,
        per_device_train_batch_size=config.batch_size,
        max_seq_length=config.max_length if config.max_length else tokenizer.model_max_length,
        learning_rate=config.learning_rate,
        num_train_epochs=config.epochs,
        logging_dir=os.path.join(base_path, "logs"),
        report_to="wandb",
        run_name=f"{config.model_name}-{config.train_round}",
        logging_steps=10,
        save_strategy="epoch",
        bf16=config.train_arg_bf16,
        fp16=config.train_arg_fp16
    )

    trainer = SFTTrainer(
        model=model,
        data_collator=DataCollatorForCompletionOnlyLM(config.response_template, tokenizer=tokenizer),
        train_dataset=dataset,
        callbacks=[ProgressCallback],
        args=trainer_args,
        formatting_func=formatting_prompts_func
    )

    training_state.update_state(message="Training...")
    trainer.train()

    # Final state update
    training_state.update_state(
        status="completed",
        message="training finished",
        model_path=model_path
    )
```
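A usage sketch for `train_model`: write a JSONL dataset in the `messages` format shown in the source comment, then launch training in a background thread so the shared `training_state` can be polled. File names and hyperparameter values are illustrative, and the default config reports to Weights & Biases and requests `flash_attention_2`, so those dependencies are assumed unless you override them.

```python
import json
import threading
from Kkit.llm_utils.fine_tune_utils import TrainConfig, train_model, training_state

# Illustrative paths; adjust to your environment.
dataset_path = "train.jsonl"
base_path = "runs/demo"

# One JSON object per line, matching the "messages" format used by the formatting function.
samples = [
    {"messages": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ]},
]
with open(dataset_path, "w") as f:
    for sample in samples:
        f.write(json.dumps(sample) + "\n")

config = TrainConfig(epochs=1, batch_size=1)
thread = threading.Thread(target=train_model, args=(config, dataset_path, base_path))
thread.start()
# training_state.get_state() can now be polled for progress (see TrainingState above).
thread.join()
```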
```python
def merge_model(
    model_name: str,
    lora_path: str,
    model_output: str,
    save_tokenizer: bool = True,
    save_config: bool = True,
    safe_serialization: bool = True,
    max_shard_size: str = "4GB",
    torch_dtype: str = "auto",
    push_to_hub: bool = False,
):
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch_dtype,
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    lora_model = PeftModel.from_pretrained(base_model, lora_path)

    merged_model = lora_model.merge_and_unload()

    merged_model.save_pretrained(
        model_output,
        safe_serialization=safe_serialization,
        max_shard_size=max_shard_size
    )

    if save_tokenizer:
        tokenizer.save_pretrained(
            model_output,
            legacy_format=True,
            safe_serialization=True
        )

    if save_config:
        merged_model.config.save_pretrained(model_output)

    if push_to_hub:
        merged_model.push_to_hub(model_output)
        tokenizer.push_to_hub(model_output)

    return merged_model, tokenizer
```
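A usage sketch for `merge_model`, assuming an adapter directory produced by `train_model`: the LoRA weights are merged into the base model and the result is saved as a regular Transformers checkpoint that can be reloaded directly. All paths below are illustrative.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from Kkit.llm_utils.fine_tune_utils import merge_model

# Hypothetical adapter checkpoint directory and output directory.
merged_model, tokenizer = merge_model(
    model_name="Qwen/Qwen2.5-0.5B",
    lora_path="runs/demo/lora_finetuned/checkpoint-100",
    model_output="runs/demo/merged",
)

# The merged directory is an ordinary Transformers checkpoint.
model = AutoModelForCausalLM.from_pretrained("runs/demo/merged", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained("runs/demo/merged")
```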