Kkit.llm_utils.fine_tune_utils

from pydantic import BaseModel
from typing import Dict, List, Optional, Union
import torch
import threading
import os
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainerCallback
)
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, PeftModel
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM


# Global training state tracker
class TrainingState:
    def __init__(self):
        self.lock = threading.Lock()
        self.current_task: Optional[Dict] = None

    def update_state(self, **kwargs):
        with self.lock:
            if self.current_task is None:
                self.current_task = {}
            self.current_task.update(kwargs)

    def get_state(self):
        with self.lock:
            return self.current_task.copy() if self.current_task else None

training_state = TrainingState()

# Training progress callback
class ProgressCallback(TrainerCallback):
    def on_train_begin(self, args, state, control, **kwargs):
        training_state.update_state(
            total_epochs=state.num_train_epochs,
            total_steps=state.max_steps
        )

    def on_epoch_end(self, args, state, control, **kwargs):
        training_state.update_state(
            current_epoch=int(state.epoch),
            current_step=state.global_step
        )

    def on_step_end(self, args, state, control, **kwargs):
        training_state.update_state(
            current_step=state.global_step,
            total_steps=state.max_steps
        )

    def on_log(self, args, state, control, logs=None, **kwargs):
        if torch.cuda.is_available():
            logs["gpu_alloc_mem"] = torch.cuda.memory_allocated() / 1024**2
            logs["gpu_reserved_mem"] = torch.cuda.memory_reserved() / 1024**2
        else:
            logs["gpu_alloc_mem"] = 0
            logs["gpu_reserved_mem"] = 0
TRAIN_ROUND = 0

# Training configuration model
class TrainConfig(BaseModel):
    model_name: str = "Qwen/Qwen2.5-0.5B"
    lora_path: Optional[str] = None      # Path to an existing LoRA adapter; start a new adapter if None, continue training it if set
    lora_rank: int = 8
    lora_alpha: int = 32
    lora_dropout: float = 0.05
    epochs: int = 3
    batch_size: int = 4
    learning_rate: float = 3e-4
    max_length: Optional[int] = None
    model_save_path: Optional[str] = None
    response_template: Optional[str] = "<|im_start|>assistant\n"   # Marker for the start of the assistant response, used for completion-only loss masking
    lora_target_modules: Union[List[str], str] = "all-linear"
    lora_modules_to_save: Optional[List[str]] = None   # e.g. ["lm_head", "embed_tokens"]
    tokenizer_padding_side: Optional[str] = "left"
    attn_implementation: str = "flash_attention_2"
    model_load_torch_dtype: str = "auto"
    train_arg_bf16: bool = True
    train_arg_fp16: bool = False
    train_round: int = TRAIN_ROUND

class MergeConfig(BaseModel):
    model_name: str
    lora_path: str
    model_output: str

# Core training function
def train_model(config: TrainConfig, dataset_path: str, base_path: str):
    training_state.update_state(
        status="training",
        message="Initializing model..."
    )

    # Handle model save path
    if config.model_save_path:
        if os.path.isabs(config.model_save_path):
            model_path = config.model_save_path
        else:
            model_path = os.path.join(base_path, config.model_save_path)
    else:
        model_path = os.path.join(base_path, "lora_finetuned")

    # Ensure directory exists with write permission
    os.makedirs(model_path, exist_ok=True)
    if not os.access(model_path, os.W_OK):
        raise RuntimeError(f"No write permission for {model_path}")

    # Load model
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if config.tokenizer_padding_side is not None:
        tokenizer.padding_side = config.tokenizer_padding_side
    base_model = AutoModelForCausalLM.from_pretrained(
        config.model_name,
        device_map="auto",
        torch_dtype=config.model_load_torch_dtype,
        attn_implementation=config.attn_implementation
    )

    if config.lora_path:
        # Load existing LoRA adapter
        model = PeftModel.from_pretrained(
            base_model,
            config.lora_path,
            device_map="auto"
        )
    else:
        # Configure new LoRA
        lora_config = LoraConfig(
            r=config.lora_rank,
            lora_alpha=config.lora_alpha,
            target_modules=config.lora_target_modules,
            lora_dropout=config.lora_dropout,
            bias="none",
            modules_to_save=config.lora_modules_to_save,
            task_type="CAUSAL_LM"
        )
        model = get_peft_model(base_model, lora_config)

    # Load and preprocess dataset
    training_state.update_state(message="Loading dataset...")
    dataset = load_dataset("json", data_files=dataset_path, split="train")

    def formatting_prompts_func(example):
        output_texts_ids = tokenizer.apply_chat_template(example["messages"])
        output_texts = tokenizer.decode(output_texts_ids)
        return output_texts

    # Example dataset format:
    # {"messages": [{"role": "user", "content": "What color is the sky?"},
    #       {"role": "assistant", "content": "It is blue."}]}

    trainer_args = SFTConfig(
        output_dir=model_path,
        per_device_train_batch_size=config.batch_size,
        max_seq_length=config.max_length if config.max_length else tokenizer.model_max_length,
        learning_rate=config.learning_rate,
        num_train_epochs=config.epochs,
        logging_dir=os.path.join(base_path, "logs"),
        report_to="wandb",
        run_name=f"{config.model_name}-{config.train_round}",
        logging_steps=10,
        save_strategy="epoch",
        bf16=config.train_arg_bf16,
        fp16=config.train_arg_fp16
    )

    trainer = SFTTrainer(
        model=model,
        data_collator=DataCollatorForCompletionOnlyLM(config.response_template, tokenizer=tokenizer),
        train_dataset=dataset,
        callbacks=[ProgressCallback],
        args=trainer_args,
        formatting_func=formatting_prompts_func
    )

    training_state.update_state(message="Training...")
    trainer.train()

    # Final state update
    training_state.update_state(
        status="completed",
        message="training finished",
        model_path=model_path
    )

def merge_model(
    model_name: str,
    lora_path: str,
    model_output: str,
    save_tokenizer: bool = True,
    save_config: bool = True,
    safe_serialization: bool = True,
    max_shard_size: str = "4GB",
    torch_dtype: str = "auto",
    push_to_hub: bool = False,
):
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch_dtype,
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    lora_model = PeftModel.from_pretrained(base_model, lora_path)

    merged_model = lora_model.merge_and_unload()

    merged_model.save_pretrained(
        model_output,
        safe_serialization=safe_serialization,
        max_shard_size=max_shard_size
    )

    if save_tokenizer:
        tokenizer.save_pretrained(
            model_output,
            legacy_format=True,
            safe_serialization=True
        )

    if save_config:
        merged_model.config.save_pretrained(model_output)

    if push_to_hub:
        merged_model.push_to_hub(model_output)
        tokenizer.push_to_hub(model_output)

    return merged_model, tokenizer
class TrainingState:

Thread-safe tracker for the progress of the fine-tuning job currently running in this process. All reads and writes of the shared task dict go through a lock, so the trainer thread and a polling thread can use it concurrently.
lock
current_task: Optional[Dict]
def update_state(self, **kwargs):

Merge the given keyword arguments into the current task dict, creating it on first use.
def get_state(self):

Return a copy of the current task dict, or None if nothing has been recorded yet.
training_state = <TrainingState object>
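
training_state is the module-level instance that ProgressCallback and train_model write to. A quick sketch of its behaviour; the keys shown here (status, current_step, ...) are simply the ones this module happens to write, not a fixed schema:

from Kkit.llm_utils.fine_tune_utils import training_state

# Nothing has been recorded yet in a fresh process
assert training_state.get_state() is None

# Writers merge partial updates under the lock; existing keys are kept
training_state.update_state(status="training", message="Initializing model...")
training_state.update_state(current_step=10, total_steps=1000)

snapshot = training_state.get_state()   # a copy, safe to inspect from another thread
print(snapshot["status"], snapshot["current_step"])   # -> training 10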
class ProgressCallback(transformers.trainer_callback.TrainerCallback):

TrainerCallback that mirrors training progress into the module-level training_state and, on every log event, adds the allocated and reserved GPU memory (in MiB) to the logged metrics.

def on_train_begin(self, args, state, control, **kwargs):

Event called at the beginning of training; records the planned number of epochs and total steps in training_state.

def on_epoch_end(self, args, state, control, **kwargs):

Event called at the end of an epoch; records the completed epoch and the current global step in training_state.

def on_step_end(self, args, state, control, **kwargs):

Event called at the end of a training step (one step may consume several batches under gradient accumulation); updates the current and total step counts in training_state.

def on_log(self, args, state, control, logs=None, **kwargs):

Event called after logging the last logs; adds the currently allocated and reserved GPU memory (in MiB) to the logged values, or zeros when CUDA is unavailable.
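
The hooks above only touch the shared state and the logs dict, so the class can be exercised directly, or attached to any Hugging Face Trainer via callbacks=[ProgressCallback]. A minimal sketch of the on_log behaviour; the loss value is made up for illustration:

from Kkit.llm_utils.fine_tune_utils import ProgressCallback

cb = ProgressCallback()
logs = {"loss": 1.23}                                        # illustrative value
cb.on_log(args=None, state=None, control=None, logs=logs)   # this hook ignores args/state/control
print(logs)
# e.g. {'loss': 1.23, 'gpu_alloc_mem': 512.0, 'gpu_reserved_mem': 1024.0}  (zeros without CUDA)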

TRAIN_ROUND = 0
class TrainConfig(pydantic.main.BaseModel):

Pydantic model bundling every knob for a LoRA fine-tuning run: the base model, LoRA hyper-parameters (rank, alpha, dropout, target modules), training settings (epochs, batch size, learning rate, sequence length, precision flags), tokenizer and attention options, and where to save the adapter. Defaults are listed in the module source above.

model_name: str
lora_path: Optional[str]
lora_rank: int
lora_alpha: int
lora_dropout: float
epochs: int
batch_size: int
learning_rate: float
max_length: Optional[int]
model_save_path: Optional[str]
response_template: Optional[str]
lora_target_modules: Union[List[str], str]
lora_modules_to_save: Optional[List[str]]
tokenizer_padding_side: Optional[str]
attn_implementation: str
model_load_torch_dtype: str
train_arg_bf16: bool
train_arg_fp16: bool
train_round: int
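A sketch of constructing a config; only the fields you want to override need to be passed, and the values below are illustrative rather than recommended settings:

from Kkit.llm_utils.fine_tune_utils import TrainConfig

config = TrainConfig(
    model_name="Qwen/Qwen2.5-0.5B",
    lora_rank=16,                        # override the default of 8
    epochs=1,
    batch_size=2,
    model_save_path="runs/qwen-lora",    # resolved relative to base_path by train_model
    attn_implementation="sdpa",          # the default "flash_attention_2" needs the flash-attn package
)

# Pydantic validates field types on construction and round-trips through plain dicts
restored = TrainConfig(**config.model_dump())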

class MergeConfig(pydantic.main.BaseModel):

Request schema for merging a LoRA adapter into its base model: the base model name, the adapter path, and the output directory for the merged weights.

model_name: str
lora_path: str
model_output: str
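MergeConfig only validates the three fields; a short sketch of using it as the argument bundle for merge_model (documented below), with placeholder paths:

from Kkit.llm_utils.fine_tune_utils import MergeConfig, merge_model

cfg = MergeConfig(
    model_name="Qwen/Qwen2.5-0.5B",
    lora_path="runs/qwen-lora/checkpoint-500",   # placeholder adapter checkpoint
    model_output="runs/qwen-merged",
)
merged_model, tokenizer = merge_model(**cfg.model_dump())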

def train_model(config: TrainConfig, dataset_path: str, base_path: str):

Run a LoRA supervised fine-tuning job. The output directory is resolved from config.model_save_path (absolute, or relative to base_path; default base_path/lora_finetuned) and checked for write access. The tokenizer and base model are loaded, an existing adapter is attached from config.lora_path or a fresh LoraConfig is applied, and a JSON/JSON-lines chat dataset is read from dataset_path. Training runs through TRL's SFTTrainer with DataCollatorForCompletionOnlyLM, so the loss is computed only on tokens after the response_template marker. Progress is mirrored into training_state via ProgressCallback, logs are reported to Weights & Biases, and a checkpoint is written at the end of every epoch. On completion, training_state is updated with status "completed" and the output path.
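
A sketch of a complete run, assuming a small JSON-lines dataset in the "messages" format shown in the source above; paths are placeholders, and note that SFTConfig reports to Weights & Biases, so wandb needs to be installed and logged in (or the report_to setting changed):

import json, threading, time
from Kkit.llm_utils.fine_tune_utils import TrainConfig, train_model, training_state

# Toy dataset: one JSON object per line, each holding a "messages" conversation
samples = [{"messages": [{"role": "user", "content": "What color is the sky?"},
                         {"role": "assistant", "content": "It is blue."}]}]
with open("train.jsonl", "w") as f:
    for s in samples:
        f.write(json.dumps(s) + "\n")

config = TrainConfig(model_name="Qwen/Qwen2.5-0.5B", epochs=1, batch_size=1,
                     attn_implementation="sdpa")   # avoid the flash-attn requirement

# train_model blocks until training finishes; run it in a thread so progress can be polled
t = threading.Thread(target=train_model, args=(config, "train.jsonl", "./runs"))
t.start()
while t.is_alive():
    print(training_state.get_state())
    time.sleep(10)
t.join()

Checkpoints are written under ./runs/lora_finetuned (one per epoch), and when the run ends training_state reports status "completed" together with the output path.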
def merge_model(model_name: str, lora_path: str, model_output: str, save_tokenizer: bool = True, save_config: bool = True, safe_serialization: bool = True, max_shard_size: str = '4GB', torch_dtype: str = 'auto', push_to_hub: bool = False):

Load the base model and tokenizer, apply the LoRA adapter from lora_path, merge the adapter weights into the base model with merge_and_unload(), and save the merged model to model_output (sharded safetensors by default). The tokenizer and model config are saved alongside unless disabled, and the result can optionally be pushed to the Hugging Face Hub. Returns the merged model and the tokenizer.
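
A sketch of merging a finished adapter back into its base model; the adapter path is a placeholder for a checkpoint produced by train_model:

from transformers import AutoModelForCausalLM
from Kkit.llm_utils.fine_tune_utils import merge_model

merged_model, tokenizer = merge_model(
    model_name="Qwen/Qwen2.5-0.5B",
    lora_path="runs/lora_finetuned/checkpoint-3",   # placeholder adapter checkpoint
    model_output="runs/qwen-merged",
    max_shard_size="2GB",
)

# The merged weights form a plain Transformers model again and can be reloaded without peft
reloaded = AutoModelForCausalLM.from_pretrained("runs/qwen-merged")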