Kkit.llm_utils.lora_fine_tune_server

import os
import argparse
import threading
import tempfile
import json
import torch
from fastapi import FastAPI, HTTPException, UploadFile, File, status, Depends, Form
from Kkit.llm_utils.fine_tune_utils import (
    train_model,
    TrainConfig,
    MergeConfig,
    training_state,
    merge_model,
)


# Initialize FastAPI application
app = FastAPI(title="LoRA Training Service")

def train_model_server(config: TrainConfig, dataset_path: str, base_path: str):
    try:
        train_model(config, dataset_path, base_path)
    except Exception as e:
        training_state.update_state(
            status="error",
            error=f"{type(e).__name__}: {str(e)}",
        )
    finally:
        # Clean up the temporary dataset file
        if os.path.exists(dataset_path):
            try:
                os.remove(dataset_path)
            except Exception as e:
                print(f"Error cleaning up temp file: {str(e)}")

def parse_config(config: str = Form(...)) -> TrainConfig:
    try:
        return TrainConfig(**json.loads(config))
    except Exception as e:
        raise HTTPException(422, detail=str(e))

# API endpoints
@app.post(
    "/train",
    status_code=status.HTTP_202_ACCEPTED,
    summary="Start a new training job"
)
def start_training(
    config: TrainConfig = Depends(parse_config),
    file: UploadFile = File(..., description="Training data (JSON format)")
):
    current_state = training_state.get_state()
    if current_state and current_state.get("status") == "training":
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Training is already in progress"
        )

    # Save uploaded file
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp_file:
            content = file.file.read()
            tmp_file.write(content)
            dataset_path = tmp_file.name
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"File upload failed: {str(e)}"
        )

    # Record the new job and start the training thread
    training_state.update_state(
        status="training",
        config=config.model_dump(),
        dataset_path=dataset_path,
    )

    thread = threading.Thread(
        target=train_model_server,
        args=(config, dataset_path, app.state.base_path)
    )
    thread.start()

    return {"message": "Training job started successfully"}

@app.get("/status", summary="Get training status")
def get_status():
    state = training_state.get_state()
    if not state:
        return {"status": "idle"}

    # GPU memory usage in MB
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**2
        reserved = torch.cuda.memory_reserved() / 1024**2
    else:
        allocated = 0
        reserved = 0

    response = {
        "status": state.get("status"),
        "message": state.get("message"),
        "current_step": state.get("current_step", 0),
        "total_steps": state.get("total_steps", 0),
        "current_epoch": state.get("current_epoch", 0),
        "total_epochs": state.get("total_epochs", 0),
        "allocated_gpu_memory": allocated,
        "reserved_gpu_memory": reserved,
        "error": state.get("error", "None"),
    }

    if state.get("status") == "completed":
        response["model_path"] = state.get("model_path")

    return response

@app.post("/merge", summary="Merge the LoRA adapter into the base model")
def merge(config: MergeConfig):
    try:
        merge_model(
            model_name=config.model_name,
            lora_path=config.lora_path,
            model_output=config.model_output
        )

        return {
            "status": "success",
            "message": "merging finished",
            "output_path": config.model_output
        }

    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail={
                "status": "error",
                "message": f"merging failed: {str(e)}"
            }
        )

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server listening host")
    parser.add_argument("--port", type=int, default=8000, help="Server listening port")
    parser.add_argument("--base_path", required=True, type=str, help="Base directory for model outputs")
    args = parser.parse_args()

    # Validate the base path, creating it if necessary
    if not os.path.exists(args.base_path):
        os.makedirs(args.base_path, exist_ok=True)
    if not os.path.isdir(args.base_path):
        raise ValueError(f"Base path {args.base_path} is not a directory")

    app.state.base_path = os.path.abspath(args.base_path)

    # Start server
    import uvicorn
    uvicorn.run(app, host=args.host, port=args.port)

# Main entry point
if __name__ == "__main__":
    main()

app = <fastapi.applications.FastAPI object>
def train_model_server(config: Kkit.llm_utils.fine_tune_utils.TrainConfig, dataset_path: str, base_path: str):

def parse_config(config: str = Form(...)) -> Kkit.llm_utils.fine_tune_utils.TrainConfig:
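
This dependency exists because FastAPI cannot read a JSON-typed Pydantic body and an uploaded file from the same multipart request, so the client sends the TrainConfig as a JSON string in a form field named config and the server re-validates it. In isolation it behaves like the sketch below (the field values are illustrative placeholders, not the confirmed TrainConfig schema):

import json
from Kkit.llm_utils.fine_tune_utils import TrainConfig

raw = '{"model_name": "Qwen/Qwen2-0.5B", "num_epochs": 3}'  # hypothetical fields
config = TrainConfig(**json.loads(raw))  # raises on invalid JSON or a schema mismatch
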
@app.post('/train', status_code=status.HTTP_202_ACCEPTED, summary='Start a new training job')
def start_training(config: Kkit.llm_utils.fine_tune_utils.TrainConfig = Depends(parse_config), file: fastapi.datastructures.UploadFile = File(...)):
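
A minimal client sketch for this endpoint, assuming the server runs at http://localhost:8000 and a local train.json dataset; the config fields shown are placeholders, since the real schema lives in fine_tune_utils.TrainConfig:

import json
import requests

config = {"model_name": "Qwen/Qwen2-0.5B", "num_epochs": 3}  # hypothetical TrainConfig fields

with open("train.json", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/train",
        data={"config": json.dumps(config)},  # JSON string form field, parsed by parse_config
        files={"file": ("train.json", f, "application/json")},
    )

print(resp.status_code, resp.json())  # expect 202 and a confirmation message
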
@app.get('/status', summary='Get training status')
def get_status():
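
A simple polling loop against this endpoint (same assumed base URL; the keys match the response dict built above):

import time
import requests

while True:
    state = requests.get("http://localhost:8000/status").json()
    print(f"{state['status']}: step {state.get('current_step', 0)}/{state.get('total_steps', 0)}")
    if state["status"] in ("idle", "completed", "error"):
        break
    time.sleep(10)
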
@app.post('/merge', summary='Merge the LoRA adapter into the base model')
def merge(config: Kkit.llm_utils.fine_tune_utils.MergeConfig):
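
A sketch of a merge request; MergeConfig's three fields come straight from the handler above, but the values here are placeholders:

import requests

resp = requests.post(
    "http://localhost:8000/merge",
    json={
        "model_name": "Qwen/Qwen2-0.5B",      # placeholder base model
        "lora_path": "outputs/lora_adapter",  # placeholder adapter directory
        "model_output": "outputs/merged",     # placeholder destination path
    },
)
print(resp.json())  # {'status': 'success', 'message': 'merging finished', ...}
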
def main():
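
To launch the service, run the module with a writable output directory, for example `python -m Kkit.llm_utils.lora_fine_tune_server --base_path ./model_outputs` (assuming the package is installed and importable). Host and port default to 0.0.0.0:8000; the base path is created if missing and handed to training jobs via app.state.base_path.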