Kkit.llm_utils.lora_fine_tune_server
import os
import argparse
import threading
import tempfile
import shutil
import wandb
import json
import torch
from fastapi import FastAPI, HTTPException, UploadFile, File, status, Depends, Form
from Kkit.llm_utils.fine_tune_utils import (
    train_model,
    TrainConfig,
    MergeConfig,
    training_state,
    merge_model,
)


# Initialize FastAPI application
app = FastAPI(title="LoRA Training Service")

# Serializes the check-then-set of the training status in start_training.
# FastAPI runs plain `def` endpoints in a worker threadpool, so without this
# lock two concurrent /train requests could both see a non-"training" status
# and each launch a job.
_training_start_lock = threading.Lock()


def _discard_temp_file(path):
    """Best-effort removal of a temporary dataset file; a no-op for None."""
    if path is None:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as e:
        print(f"Error cleaning up temp file: {str(e)}")


def train_model_server(config: TrainConfig, dataset_path: str, base_path: str):
    """Run one training job; used as the target of the background thread.

    Any exception from train_model is recorded in training_state as
    status="error" (nothing above a thread target could catch it). The
    dataset file at dataset_path is deleted afterwards in every case.
    """
    try:
        train_model(config, dataset_path, base_path)
    except Exception as e:
        training_state.update_state(
            status="error",
            error=f"{type(e).__name__}: {str(e)}",
        )
    finally:
        # Cleanup temporary files
        if os.path.exists(dataset_path):
            try:
                os.remove(dataset_path)
            except Exception as e:
                print(f"Error cleaning up temp file: {str(e)}")


def parse_config(config: str = Form(...)) -> TrainConfig:
    """Decode the multipart ``config`` form field (a JSON string) into a TrainConfig.

    Raises:
        HTTPException: 422 when json.loads or TrainConfig construction fails.
    """
    try:
        return TrainConfig(**json.loads(config))
    except Exception as e:
        raise HTTPException(422, detail=str(e)) from e


# API endpoints
@app.post(
    "/train",
    status_code=status.HTTP_202_ACCEPTED,
    summary="Start a new training job"
)
def start_training(
    config: TrainConfig = Depends(parse_config),
    file: UploadFile = File(..., description="Training data (JSON format)")
):
    """Save the uploaded dataset to a temp file and launch training in a background thread.

    Only one job may be in flight: the request is refused while training_state
    reports status "training". The temp file is handed to train_model_server,
    which deletes it when the job ends.

    Returns:
        A confirmation message; progress is then reported by /status.

    Raises:
        HTTPException: 409 if a job is already training; 500 if app.state has
            no base_path (the server was not started via main()), if the
            upload cannot be written to disk, or if the worker thread cannot
            be started.
    """
    # Resolve before touching any state. Failing later (as the original did)
    # would leave the job marked "training" forever and leak the temp file.
    base_path = getattr(app.state, "base_path", None)
    if base_path is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Server base_path is not configured"
        )

    with _training_start_lock:
        current_state = training_state.get_state()
        if current_state and current_state.get("status") == "training":
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="Training is already in progress"
            )

        # Save uploaded file (streamed, not read fully into memory).
        dataset_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp_file:
                dataset_path = tmp_file.name
                shutil.copyfileobj(file.file, tmp_file)
        except Exception as e:
            # delete=False means a partially written file would otherwise leak.
            _discard_temp_file(dataset_path)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"File upload failed: {str(e)}"
            ) from e

        # Start training thread
        training_state.update_state(
            status="training",
            config=config.model_dump(),
            dataset_path=dataset_path,
        )

        thread = threading.Thread(
            target=train_model_server,
            args=(config, dataset_path, base_path)
        )
        try:
            thread.start()
        except RuntimeError as e:
            # Roll back so the service does not report a phantom running job.
            training_state.update_state(
                status="error",
                error=f"{type(e).__name__}: {str(e)}",
            )
            _discard_temp_file(dataset_path)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to start training thread: {str(e)}"
            ) from e

    return {"message": "Training job started successfully"}


@app.get("/status", summary="Get training status")
def get_status():
    """Report the training job's state plus this process's CUDA memory usage.

    Returns {"status": "idle"} when training_state holds no state. Otherwise
    returns progress counters (0 when absent), GPU memory in MiB for the
    current CUDA device (0 without CUDA), the recorded error, and, once the
    status is "completed", the model_path.
    """
    state = training_state.get_state()
    if not state:
        return {"status": "idle"}

    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**2
        reserved = torch.cuda.memory_reserved() / 1024**2
    else:
        allocated = 0
        reserved = 0

    response = {
        "status": state.get("status"),
        "message": state.get("message"),
        "current_step": state.get("current_step", 0),
        "total_steps": state.get("total_steps", 0),
        "current_epoch": state.get("current_epoch", 0),
        "total_epochs": state.get("total_epochs", 0),
        "allocated_gpu_memory": allocated,
        "reserved_gpu_memory": reserved,
        # Defaults to the string "None" (not JSON null) when no error recorded.
        "error": state.get("error", "None"),
    }

    if state.get("status") == "completed":
        response["model_path"] = state.get("model_path")

    return response


# summary: "Merge LoRA adapter into base model"
@app.post("/merge", summary="合并LoRA适配器到基础模型")
def merge(config: MergeConfig):
    """Merge a LoRA adapter into its base model, synchronously within the request.

    Returns:
        {"status": "success", "message": ..., "output_path": config.model_output}.

    Raises:
        HTTPException: 500 with an {"status": "error", "message": ...} detail
            when merge_model raises.
    """
    try:
        merge_model(
            model_name=config.model_name,
            lora_path=config.lora_path,
            model_output=config.model_output
        )

        return {
            "status": "success",
            "message": "merging finished",
            "output_path": config.model_output
        }

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={
                "status": "error",
                "message": f"merging failed: {str(e)}"
            }
        ) from e


def main():
    """Parse CLI flags, ensure --base_path is a directory, and serve the app with uvicorn.

    Raises:
        ValueError: if --base_path exists but is not a directory.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server listening host")
    parser.add_argument("--port", type=int, default=8000, help="Server listening port")
    parser.add_argument("--base_path", required=True, type=str, help="Base directory for model outputs")
    args = parser.parse_args()

    # Validate base path
    if not os.path.exists(args.base_path):
        os.makedirs(args.base_path, exist_ok=True)
    if not os.path.isdir(args.base_path):
        raise ValueError(f"Base path {args.base_path} is not a directory")

    # Read by start_training when launching a job.
    app.state.base_path = os.path.abspath(args.base_path)

    # Start server
    import uvicorn
    uvicorn.run(app, host=args.host, port=args.port)


# Main entry point
if __name__ == "__main__":
    main()
app =
<fastapi.applications.FastAPI object>
def
train_model_server( config: Kkit.llm_utils.fine_tune_utils.TrainConfig, dataset_path: str, base_path: str):
def _cleanup_dataset(path: str) -> None:
    """Delete the temporary dataset file if it still exists, reporting failures via print."""
    if not os.path.exists(path):
        return
    try:
        os.remove(path)
    except Exception as exc:
        print(f"Error cleaning up temp file: {str(exc)}")


def train_model_server(config: TrainConfig, dataset_path: str, base_path: str):
    """Execute a single training job; intended as a background-thread target.

    A failure inside train_model is not propagated (nothing above a thread
    target could catch it); it is stored in training_state as status="error"
    with an "ExcType: message" string. The dataset file is always removed.
    """
    try:
        train_model(config, dataset_path, base_path)
    except Exception as exc:
        failure = f"{type(exc).__name__}: {str(exc)}"
        training_state.update_state(status="error", error=failure)
    finally:
        _cleanup_dataset(dataset_path)
def
parse_config( config: str = Form(PydanticUndefined)) -> Kkit.llm_utils.fine_tune_utils.TrainConfig:
@app.post('/train', status_code=status.HTTP_202_ACCEPTED, summary='Start a new training job')
def
start_training( config: Kkit.llm_utils.fine_tune_utils.TrainConfig = Depends(parse_config), file: fastapi.datastructures.UploadFile = File(PydanticUndefined)):
# Serializes the check-then-set of the training status in start_training.
# FastAPI runs plain `def` endpoints in a worker threadpool, so without this
# lock two concurrent /train requests could both see a non-"training" status
# and each launch a job.
_training_start_lock = threading.Lock()


def _discard_temp_file(path):
    """Best-effort removal of a temporary dataset file; a no-op for None."""
    if path is None:
        return
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as e:
        print(f"Error cleaning up temp file: {str(e)}")


@app.post(
    "/train",
    status_code=status.HTTP_202_ACCEPTED,
    summary="Start a new training job"
)
def start_training(
    config: TrainConfig = Depends(parse_config),
    file: UploadFile = File(..., description="Training data (JSON format)")
):
    """Save the uploaded dataset to a temp file and launch training in a background thread.

    Only one job may be in flight: the request is refused while training_state
    reports status "training". The temp file is handed to train_model_server,
    which deletes it when the job ends.

    Returns:
        A confirmation message; progress is then reported by /status.

    Raises:
        HTTPException: 409 if a job is already training; 500 if app.state has
            no base_path (the server was not started via main()), if the
            upload cannot be written to disk, or if the worker thread cannot
            be started.
    """
    import shutil

    # Resolve before touching any state. Failing later (as the original did)
    # would leave the job marked "training" forever and leak the temp file.
    base_path = getattr(app.state, "base_path", None)
    if base_path is None:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Server base_path is not configured"
        )

    with _training_start_lock:
        current_state = training_state.get_state()
        if current_state and current_state.get("status") == "training":
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="Training is already in progress"
            )

        # Save uploaded file (streamed, not read fully into memory).
        dataset_path = None
        try:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp_file:
                dataset_path = tmp_file.name
                shutil.copyfileobj(file.file, tmp_file)
        except Exception as e:
            # delete=False means a partially written file would otherwise leak.
            _discard_temp_file(dataset_path)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"File upload failed: {str(e)}"
            ) from e

        # Start training thread
        training_state.update_state(
            status="training",
            config=config.model_dump(),
            dataset_path=dataset_path,
        )

        thread = threading.Thread(
            target=train_model_server,
            args=(config, dataset_path, base_path)
        )
        try:
            thread.start()
        except RuntimeError as e:
            # Roll back so the service does not report a phantom running job.
            training_state.update_state(
                status="error",
                error=f"{type(e).__name__}: {str(e)}",
            )
            _discard_temp_file(dataset_path)
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to start training thread: {str(e)}"
            ) from e

    return {"message": "Training job started successfully"}
@app.get('/status', summary='Get training status')
def
get_status():
@app.get("/status", summary="Get training status")
def get_status():
    """Summarize the training job's progress and this process's CUDA memory use.

    With no recorded state the reply is just {"status": "idle"}. Otherwise it
    carries status/message, the step and epoch counters (0 when missing), the
    allocated and reserved GPU memory in MiB for the current CUDA device (0
    when CUDA is unavailable), the error field, and model_path once the
    status is "completed".
    """
    state = training_state.get_state()
    if not state:
        return {"status": "idle"}

    bytes_per_mib = 1024**2
    has_cuda = torch.cuda.is_available()
    allocated = torch.cuda.memory_allocated() / bytes_per_mib if has_cuda else 0
    reserved = torch.cuda.memory_reserved() / bytes_per_mib if has_cuda else 0

    progress_fields = ("current_step", "total_steps", "current_epoch", "total_epochs")

    # Key insertion order is kept identical to the response schema.
    response = {"status": state.get("status"), "message": state.get("message")}
    response.update({field: state.get(field, 0) for field in progress_fields})
    response["allocated_gpu_memory"] = allocated
    response["reserved_gpu_memory"] = reserved
    # Falls back to the literal string "None", not JSON null.
    response["error"] = state.get("error", "None")

    if state.get("status") == "completed":
        response["model_path"] = state.get("model_path")

    return response
@app.post('/merge', summary='合并LoRA适配器到基础模型')
def
merge(config: Kkit.llm_utils.fine_tune_utils.MergeConfig):
# summary: "Merge LoRA adapter into base model"
@app.post("/merge", summary="合并LoRA适配器到基础模型")
def merge(config: MergeConfig):
    """Fold a LoRA adapter into its base model and write the result to disk.

    Runs synchronously inside the request. On success, reports the output
    location taken from config.model_output. Any exception raised by
    merge_model becomes an HTTP 500 whose detail is a
    {"status": "error", "message": ...} dict.
    """
    try:
        merge_model(
            model_name=config.model_name,
            lora_path=config.lora_path,
            model_output=config.model_output,
        )
    except Exception as exc:
        failure_detail = {
            "status": "error",
            "message": f"merging failed: {str(exc)}",
        }
        raise HTTPException(status_code=500, detail=failure_detail)

    return {
        "status": "success",
        "message": "merging finished",
        "output_path": config.model_output,
    }
def
main():
def main():
    """Command-line entry point: read options, prepare the output root, run uvicorn.

    The --base_path directory is created when missing, and its absolute form
    is stored on app.state for the /train endpoint.

    Raises:
        ValueError: when --base_path exists but is not a directory.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="0.0.0.0", help="Server listening host")
    cli.add_argument("--port", type=int, default=8000, help="Server listening port")
    cli.add_argument("--base_path", required=True, type=str, help="Base directory for model outputs")
    opts = cli.parse_args()

    output_root = opts.base_path
    if not os.path.exists(output_root):
        os.makedirs(output_root, exist_ok=True)
    if not os.path.isdir(output_root):
        raise ValueError(f"Base path {output_root} is not a directory")

    app.state.base_path = os.path.abspath(output_root)

    # Deferred import: uvicorn is only needed when serving from the CLI.
    import uvicorn
    uvicorn.run(app, host=opts.host, port=opts.port)