Kkit.llm_utils.llamacpp_wrapper_server

  1import subprocess
  2import json
  3import time
  4from fastapi import FastAPI, HTTPException
  5from pydantic import BaseModel
  6from typing import Dict, Any
  7import uvicorn
  8from contextlib import asynccontextmanager
  9import threading
 10import argparse
 11from typing import List
 12from fastapi.logger import logger
 13from logging import StreamHandler, Formatter
 14import logging
 15
 16
 17server_lock = threading.Lock()
 18llama_server = None
 19
 20class ServerConfig(BaseModel):
 21    model_name: str
 22    configs: Dict[str, Any]
 23    server_path: List[str] = ["./llama-server"]
 24
 25class SwitchRequest(BaseModel):
 26    new_model_name: str
 27    new_configs: Dict[str, Any]
 28
 29class LlamaServer:
 30    def __init__(self, model_name: str, configs: Dict[str, Any], server_path: List[str]):
 31        self.server_path = server_path
 32        self.model_name = model_name
 33        self.configs = configs.copy()
 34        self.process = None
 35
 36    def _convert_configs_to_args(self) -> list:
 37        args = []
 38        for key, value in self.configs.items():
 39            if key.startswith("-"):
 40                arg_name = key
 41            else:
 42                arg_name = f"--{key.replace('_', '-')}"
 43            
 44            if isinstance(value, bool):
 45                if value:
 46                    args.append(arg_name)
 47            elif value == '':
 48                args.extend([arg_name])
 49            else:
 50                args.extend([arg_name, str(value)])
 51        return args
 52
 53    def start(self, log_file: str = "llama_server.log"):
 54        if self.process and self.process.poll() is None:
 55            raise RuntimeError("Server is already running")
 56        cmd = (self.server_path).copy()
 57        if app.state.llama_cpp_or_vllm == "llama_cpp":
 58            cmd.extend(["--model", self.model_name])
 59        elif app.state.llama_cpp_or_vllm == "vllm":
 60            cmd.extend([self.model_name])
 61        cmd.extend(self._convert_configs_to_args())
 62        print(cmd)
 63
 64        try:
 65            with open(log_file, "w") as log_f:
 66                self.process = subprocess.Popen(
 67                    cmd,
 68                    stdout=log_f,
 69                    stderr=log_f,
 70                    text=True,
 71                )
 72            time.sleep(1)
 73            if self.process.poll() is not None:
 74                raise RuntimeError(f"Server failed to start. Exit code: {self.process.returncode}")
 75        except Exception as e:
 76            self.process = None
 77            raise RuntimeError(f"Error starting server: {str(e)}")
 78
 79    def stop(self):
 80        if self.process:
 81            try:
 82                self.process.terminate()
 83                self.process.wait(timeout=5)
 84            except subprocess.TimeoutExpired:
 85                self.process.kill()
 86            finally:
 87                self.process = None
 88
 89    def switch(self, new_model_name: str, new_configs: Dict[str, Any]):
 90        self.stop()
 91        self.model_name = new_model_name
 92        self.configs = new_configs.copy()
 93        self.start(app.state.log_file)
 94
 95    def is_running(self) -> bool:
 96        return self.process is not None and self.process.poll() is None
 97
 98@asynccontextmanager
 99async def lifespan(app: FastAPI):
100    logger.info(f"Run in {app.state.llama_cpp_or_vllm} mode...")
101    yield
102    logger.info("Stopping server...")
103    global llama_server
104    if llama_server and llama_server.is_running():
105        llama_server.stop()
106
107app = FastAPI(lifespan=lifespan)
108
@app.post("/wrapper_start")
def start_server(config: ServerConfig):
    """Launch a backend for ``config.model_name`` with ``config.configs``.

    Responds 400 if a backend is already running and 500 if the launch fails.
    """
    global llama_server
    with server_lock:
        already_up = llama_server is not None and llama_server.is_running()
        if already_up:
            raise HTTPException(status_code=400, detail="Server is already running")

        try:
            llama_server = LlamaServer(
                model_name=config.model_name,
                configs=config.configs,
                server_path=config.server_path,
            )
            llama_server.start(app.state.log_file)
        except Exception as err:
            raise HTTPException(status_code=500, detail=str(err))

        reply = {"status": "success"}
        reply["message"] = f"Server started with model: {config.model_name}"
        reply["model"] = config.model_name
        reply["config"] = config.configs
        return reply
131
@app.post("/wrapper_switch")
def switch_model(request: SwitchRequest):
    """Relaunch the running backend with a different model and option set.

    Responds 400 when no backend is running and 500 when the relaunch fails.
    """
    global llama_server
    with server_lock:
        if llama_server is None or not llama_server.is_running():
            raise HTTPException(status_code=400, detail="Server is not running")

        target_model = request.new_model_name
        target_configs = request.new_configs
        try:
            llama_server.switch(new_model_name=target_model, new_configs=target_configs)
        except Exception as err:
            raise HTTPException(status_code=500, detail=str(err))

        return {
            "status": "success",
            "message": f"Switched to model: {target_model}",
            "new_model": target_model,
            "new_config": target_configs,
        }
152
@app.post("/wrapper_stop")
def stop_server():
    """Stop the running backend process.

    Responds 400 when nothing is running and 500 if stopping raises.
    """
    global llama_server
    with server_lock:
        running = llama_server is not None and llama_server.is_running()
        if not running:
            raise HTTPException(status_code=400, detail="Server is not running")

        try:
            llama_server.stop()
        except Exception as err:
            raise HTTPException(status_code=500, detail=str(err))
        return {"status": "success", "message": "Server stopped"}
165
@app.get("/wrapper_status")
def get_status():
    """Report whether a backend is running and, if so, its model and configs.

    Does not take ``server_lock``; reads the current module-level state.
    """
    global llama_server
    running = llama_server is not None and llama_server.is_running()
    if not running:
        return {"is_running": False, "current_model": None, "current_config": None}
    return {
        "is_running": True,
        "current_model": llama_server.model_name,
        "current_config": llama_server.configs,
    }
183
def main():
    """Parse CLI options, configure logging, and serve the wrapper with uvicorn.

    Stores ``--log_file`` and ``--llama_cpp_or_vllm`` on ``app.state`` for the
    endpoints and ``LlamaServer`` to read.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server listening host")
    parser.add_argument("--port", type=int, default=8001, help="Server listening port")
    parser.add_argument("--log_file", type=str, default="./server.log", help="Log file path")
    # `choices` enforces the validation the old no-op `x in [...]` expression intended.
    parser.add_argument(
        "--llama_cpp_or_vllm", "-lv",
        default="vllm",
        choices=["vllm", "llama_cpp"],
        help="using `vllm` or `llama_cpp` as backend",
    )
    args = parser.parse_args()
    app.state.log_file = args.log_file
    app.state.llama_cpp_or_vllm = args.llama_cpp_or_vllm

    handler = StreamHandler()
    handler.setFormatter(Formatter("%(levelname)s:     %(message)s"))
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    # Honour --host (previously parsed but ignored in favour of "0.0.0.0").
    uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()
server_lock = <unlocked _thread.lock object>
llama_server = None
class ServerConfig(pydantic.main.BaseModel):
class ServerConfig(BaseModel):
    """Request body for /wrapper_start."""

    # Model path/identifier passed to the backend executable.
    model_name: str
    # Extra CLI options; rendered to flags by LlamaServer._convert_configs_to_args.
    configs: Dict[str, Any]
    # Executable (plus any leading arguments) used to launch the backend.
    server_path: List[str] = ["./llama-server"]

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
__class_vars__: The names of the class variables defined on the model.
__private_attributes__: Metadata about the private attributes of the model.
__signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.

__pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__: The core schema of the model.
__pydantic_custom_init__: Whether the model has a custom `__init__` function.
__pydantic_decorators__: Metadata containing the decorators defined on the model.
    This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
__pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to
    __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__: The name of the post-init method for the model, if defined.
__pydantic_root_model__: Whether the model is a [`RootModel`][pydantic.root_model.RootModel].
__pydantic_serializer__: The `pydantic-core` `SchemaSerializer` used to dump instances of the model.
__pydantic_validator__: The `pydantic-core` `SchemaValidator` used to validate instances of the model.

__pydantic_fields__: A dictionary of field names and their corresponding [`FieldInfo`][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__: A dictionary of computed field names and their corresponding [`ComputedFieldInfo`][pydantic.fields.ComputedFieldInfo] objects.

__pydantic_extra__: A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra]
    is set to `'allow'`.
__pydantic_fields_set__: The names of fields explicitly set during instantiation.
__pydantic_private__: Values of private attributes set on the model instance.
model_name: str
configs: Dict[str, Any]
server_path: List[str]
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class SwitchRequest(pydantic.main.BaseModel):
class SwitchRequest(BaseModel):
    """Request body for /wrapper_switch."""

    # Model to relaunch the backend with.
    new_model_name: str
    # Replacement CLI options (fully replace the previous configs).
    new_configs: Dict[str, Any]

Usage docs: https://docs.pydantic.dev/2.10/concepts/models/

A base class for creating Pydantic models.

Attributes:
__class_vars__: The names of the class variables defined on the model.
__private_attributes__: Metadata about the private attributes of the model.
__signature__: The synthesized __init__ [Signature][inspect.Signature] of the model.

__pydantic_complete__: Whether model building is completed, or if there are still undefined fields.
__pydantic_core_schema__: The core schema of the model.
__pydantic_custom_init__: Whether the model has a custom `__init__` function.
__pydantic_decorators__: Metadata containing the decorators defined on the model.
    This replaces `Model.__validators__` and `Model.__root_validators__` from Pydantic V1.
__pydantic_generic_metadata__: Metadata for generic models; contains data used for a similar purpose to
    __args__, __origin__, __parameters__ in typing-module generics. May eventually be replaced by these.
__pydantic_parent_namespace__: Parent namespace of the model, used for automatic rebuilding of models.
__pydantic_post_init__: The name of the post-init method for the model, if defined.
__pydantic_root_model__: Whether the model is a [`RootModel`][pydantic.root_model.RootModel].
__pydantic_serializer__: The `pydantic-core` `SchemaSerializer` used to dump instances of the model.
__pydantic_validator__: The `pydantic-core` `SchemaValidator` used to validate instances of the model.

__pydantic_fields__: A dictionary of field names and their corresponding [`FieldInfo`][pydantic.fields.FieldInfo] objects.
__pydantic_computed_fields__: A dictionary of computed field names and their corresponding [`ComputedFieldInfo`][pydantic.fields.ComputedFieldInfo] objects.

__pydantic_extra__: A dictionary containing extra values, if [`extra`][pydantic.config.ConfigDict.extra]
    is set to `'allow'`.
__pydantic_fields_set__: The names of fields explicitly set during instantiation.
__pydantic_private__: Values of private attributes set on the model instance.
new_model_name: str
new_configs: Dict[str, Any]
model_config: ClassVar[pydantic.config.ConfigDict] = {}

Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].

class LlamaServer:
class LlamaServer:
    """Manage one backend inference-server subprocess (llama.cpp or vLLM).

    The launched command is ``server_path`` + the model argument + the entries
    of ``configs`` rendered as CLI flags. Which model-argument form is used is
    read from the module-level ``app.state.llama_cpp_or_vllm``.
    """

    def __init__(self, model_name: str, configs: Dict[str, Any], server_path: List[str]):
        """Store launch parameters; no subprocess is started here.

        Args:
            model_name: Model path/identifier handed to the backend.
            configs: Extra CLI options; copied so later mutation by the caller
                does not affect this instance.
            server_path: Executable (plus any leading arguments) to launch.
        """
        self.server_path = server_path
        self.model_name = model_name
        self.configs = configs.copy()
        self.process = None  # subprocess.Popen handle once launched

    def _convert_configs_to_args(self) -> list:
        """Render ``self.configs`` as a flat argv list.

        Keys starting with ``-`` are used verbatim; other keys become
        ``--key-name`` (underscores replaced by dashes). ``True`` and ``''``
        emit the bare flag, ``False`` omits it, and any other value emits the
        flag followed by ``str(value)``.
        """
        args = []
        for key, value in self.configs.items():
            arg_name = key if key.startswith("-") else f"--{key.replace('_', '-')}"
            if isinstance(value, bool):
                if value:
                    args.append(arg_name)
            elif value == '':
                args.append(arg_name)
            else:
                args.extend([arg_name, str(value)])
        return args

    def start(self, log_file: str = "llama_server.log"):
        """Launch the backend, sending its stdout/stderr to ``log_file``.

        The log file is truncated on every start. After launching, this waits
        one second and treats an already-exited process as a failed start.

        Raises:
            RuntimeError: if a process is already running, the backend mode in
                ``app.state.llama_cpp_or_vllm`` is unknown, or the launch fails.
        """
        if self.is_running():
            raise RuntimeError("Server is already running")
        backend = app.state.llama_cpp_or_vllm
        cmd = list(self.server_path)
        if backend == "llama_cpp":
            cmd.extend(["--model", self.model_name])
        elif backend == "vllm":
            cmd.append(self.model_name)
        else:
            # Previously the model argument was silently dropped here.
            raise RuntimeError(f"Unknown backend: {backend!r}")
        cmd.extend(self._convert_configs_to_args())
        print(cmd)

        try:
            # Popen hands the child its own copy of the descriptor, so closing
            # our handle when the with-block exits does not cut off its output.
            with open(log_file, "w") as log_f:
                self.process = subprocess.Popen(
                    cmd,
                    stdout=log_f,
                    stderr=log_f,
                    text=True,
                )
            time.sleep(1)
            if self.process.poll() is not None:
                raise RuntimeError(f"Server failed to start. Exit code: {self.process.returncode}")
        except Exception as e:
            self.process = None
            raise RuntimeError(f"Error starting server: {str(e)}") from e

    def stop(self):
        """Terminate the subprocess (force-kill after 5 s) and reap it.

        No-op when nothing was launched; ``self.process`` is always cleared.
        """
        if self.process is None:
            return
        try:
            self.process.terminate()
            try:
                self.process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # Graceful shutdown timed out: kill, then wait so the child is
                # reaped instead of lingering as a zombie.
                self.process.kill()
                self.process.wait()
        finally:
            self.process = None

    def switch(self, new_model_name: str, new_configs: Dict[str, Any]):
        """Stop the current process and relaunch with a new model and configs.

        Logs go to ``app.state.log_file``. Raises RuntimeError from ``start``.
        """
        self.stop()
        self.model_name = new_model_name
        self.configs = new_configs.copy()
        self.start(app.state.log_file)

    def is_running(self) -> bool:
        """Return True if a launched process has not yet exited."""
        return self.process is not None and self.process.poll() is None
LlamaServer(model_name: str, configs: Dict[str, Any], server_path: List[str])
def __init__(self, model_name: str, configs: Dict[str, Any], server_path: List[str]):
    """Record launch parameters for a LlamaServer; nothing is started yet.

    ``configs`` is shallow-copied so the caller's dict can change freely.
    """
    self.server_path = server_path
    self.model_name = model_name
    self.configs = configs.copy()
    # No subprocess until start() succeeds.
    self.process = None
server_path
model_name
configs
process
def start(self, log_file: str = 'llama_server.log'):
def start(self, log_file: str = "llama_server.log"):
    """Launch the backend, sending its stdout/stderr to ``log_file``.

    The log file is truncated on every start. After launching, this waits one
    second and treats an already-exited process as a failed start.

    Raises:
        RuntimeError: if a process is already running, the backend mode in
            ``app.state.llama_cpp_or_vllm`` is unknown, or the launch fails.
    """
    if self.process is not None and self.process.poll() is None:
        raise RuntimeError("Server is already running")
    backend = app.state.llama_cpp_or_vllm
    cmd = list(self.server_path)
    if backend == "llama_cpp":
        cmd.extend(["--model", self.model_name])
    elif backend == "vllm":
        cmd.append(self.model_name)
    else:
        # Previously the model argument was silently dropped here.
        raise RuntimeError(f"Unknown backend: {backend!r}")
    cmd.extend(self._convert_configs_to_args())
    print(cmd)

    try:
        # Popen hands the child its own copy of the descriptor, so closing our
        # handle when the with-block exits does not cut off its output.
        with open(log_file, "w") as log_f:
            self.process = subprocess.Popen(
                cmd,
                stdout=log_f,
                stderr=log_f,
                text=True,
            )
        time.sleep(1)
        if self.process.poll() is not None:
            raise RuntimeError(f"Server failed to start. Exit code: {self.process.returncode}")
    except Exception as e:
        self.process = None
        raise RuntimeError(f"Error starting server: {str(e)}") from e
def stop(self):
def stop(self):
    """Terminate the subprocess (force-kill after 5 s) and reap it.

    No-op when nothing was launched; ``self.process`` is always cleared.
    """
    if self.process is None:
        return
    try:
        self.process.terminate()
        try:
            self.process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # Graceful shutdown timed out: kill, then wait so the child is
            # reaped instead of lingering as a zombie.
            self.process.kill()
            self.process.wait()
    finally:
        self.process = None
def switch(self, new_model_name: str, new_configs: Dict[str, Any]):
def switch(self, new_model_name: str, new_configs: Dict[str, Any]):
    """Replace the running backend: stop it, swap model/configs, relaunch.

    The relaunch logs to ``app.state.log_file``; errors from start() propagate.
    """
    self.stop()
    # configs is copied so the caller's dict is not aliased.
    self.model_name, self.configs = new_model_name, new_configs.copy()
    log_path = app.state.log_file
    self.start(log_path)
def is_running(self) -> bool:
def is_running(self) -> bool:
    """True when a launched process exists and poll() reports it still alive."""
    if self.process is None:
        return False
    return self.process.poll() is None
@asynccontextmanager
async def lifespan(app: fastapi.applications.FastAPI):
 99@asynccontextmanager
100async def lifespan(app: FastAPI):
101    logger.info(f"Run in {app.state.llama_cpp_or_vllm} mode...")
102    yield
103    logger.info("Stopping server...")
104    global llama_server
105    if llama_server and llama_server.is_running():
106        llama_server.stop()
app = <fastapi.applications.FastAPI object>
@app.post('/wrapper_start')
def start_server(config: ServerConfig):
@app.post("/wrapper_start")
def start_server(config: ServerConfig):
    """Launch a backend for ``config.model_name`` with ``config.configs``.

    Responds 400 if a backend is already running and 500 if the launch fails.
    """
    global llama_server
    with server_lock:
        already_up = llama_server is not None and llama_server.is_running()
        if already_up:
            raise HTTPException(status_code=400, detail="Server is already running")

        try:
            llama_server = LlamaServer(
                model_name=config.model_name,
                configs=config.configs,
                server_path=config.server_path,
            )
            llama_server.start(app.state.log_file)
        except Exception as err:
            raise HTTPException(status_code=500, detail=str(err))

        reply = {"status": "success"}
        reply["message"] = f"Server started with model: {config.model_name}"
        reply["model"] = config.model_name
        reply["config"] = config.configs
        return reply
@app.post('/wrapper_switch')
def switch_model(request: SwitchRequest):
@app.post("/wrapper_switch")
def switch_model(request: SwitchRequest):
    """Relaunch the running backend with a different model and option set.

    Responds 400 when no backend is running and 500 when the relaunch fails.
    """
    global llama_server
    with server_lock:
        if llama_server is None or not llama_server.is_running():
            raise HTTPException(status_code=400, detail="Server is not running")

        target_model = request.new_model_name
        target_configs = request.new_configs
        try:
            llama_server.switch(new_model_name=target_model, new_configs=target_configs)
        except Exception as err:
            raise HTTPException(status_code=500, detail=str(err))

        return {
            "status": "success",
            "message": f"Switched to model: {target_model}",
            "new_model": target_model,
            "new_config": target_configs,
        }
@app.post('/wrapper_stop')
def stop_server():
@app.post("/wrapper_stop")
def stop_server():
    """Stop the running backend process.

    Responds 400 when nothing is running and 500 if stopping raises.
    """
    global llama_server
    with server_lock:
        running = llama_server is not None and llama_server.is_running()
        if not running:
            raise HTTPException(status_code=400, detail="Server is not running")

        try:
            llama_server.stop()
        except Exception as err:
            raise HTTPException(status_code=500, detail=str(err))
        return {"status": "success", "message": "Server stopped"}
@app.get('/wrapper_status')
def get_status():
@app.get("/wrapper_status")
def get_status():
    """Report whether a backend is running and, if so, its model and configs.

    Does not take ``server_lock``; reads the current module-level state.
    """
    global llama_server
    running = llama_server is not None and llama_server.is_running()
    if not running:
        return {"is_running": False, "current_model": None, "current_config": None}
    return {
        "is_running": True,
        "current_model": llama_server.model_name,
        "current_config": llama_server.configs,
    }
def main():
def main():
    """Parse CLI options, configure logging, and serve the wrapper with uvicorn.

    Stores ``--log_file`` and ``--llama_cpp_or_vllm`` on ``app.state`` for the
    endpoints and ``LlamaServer`` to read.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server listening host")
    parser.add_argument("--port", type=int, default=8001, help="Server listening port")
    parser.add_argument("--log_file", type=str, default="./server.log", help="Log file path")
    # `choices` enforces the validation the old no-op `x in [...]` expression intended.
    parser.add_argument(
        "--llama_cpp_or_vllm", "-lv",
        default="vllm",
        choices=["vllm", "llama_cpp"],
        help="using `vllm` or `llama_cpp` as backend",
    )
    args = parser.parse_args()
    app.state.log_file = args.log_file
    app.state.llama_cpp_or_vllm = args.llama_cpp_or_vllm

    handler = StreamHandler()
    handler.setFormatter(Formatter("%(levelname)s:     %(message)s"))
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    # Honour --host (previously parsed but ignored in favour of "0.0.0.0").
    uvicorn.run(app, host=args.host, port=args.port)