Kkit.llm_utils.llamacpp_wrapper_server
import subprocess
import time
import threading
import argparse
import logging
from logging import StreamHandler, Formatter
from contextlib import asynccontextmanager
from typing import Dict, Any, List

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.logger import logger
from pydantic import BaseModel


server_lock = threading.Lock()
llama_server = None

class ServerConfig(BaseModel):
    model_name: str
    configs: Dict[str, Any]
    server_path: List[str] = ["./llama-server"]

class SwitchRequest(BaseModel):
    new_model_name: str
    new_configs: Dict[str, Any]

class LlamaServer:
    def __init__(self, model_name: str, configs: Dict[str, Any], server_path: List[str]):
        self.server_path = server_path
        self.model_name = model_name
        self.configs = configs.copy()
        self.process = None

    def _convert_configs_to_args(self) -> list:
        # Translate the configs dict into CLI arguments for the backend:
        # keys already starting with "-" pass through verbatim, others
        # become "--kebab-case" flags; True and "" yield bare flags.
        args = []
        for key, value in self.configs.items():
            if key.startswith("-"):
                arg_name = key
            else:
                arg_name = f"--{key.replace('_', '-')}"

            if isinstance(value, bool):
                if value:
                    args.append(arg_name)
            elif value == '':
                args.append(arg_name)
            else:
                args.extend([arg_name, str(value)])
        return args

    def start(self, log_file: str = "llama_server.log"):
        if self.process and self.process.poll() is None:
            raise RuntimeError("Server is already running")
        cmd = self.server_path.copy()
        if app.state.llama_cpp_or_vllm == "llama_cpp":
            cmd.extend(["--model", self.model_name])
        elif app.state.llama_cpp_or_vllm == "vllm":
            cmd.append(self.model_name)
        cmd.extend(self._convert_configs_to_args())
        print(cmd)

        try:
            with open(log_file, "w") as log_f:
                self.process = subprocess.Popen(
                    cmd,
                    stdout=log_f,
                    stderr=log_f,
                    text=True,
                )
            # Give the process a moment to crash on bad arguments; this is
            # not a readiness check.
            time.sleep(1)
            if self.process.poll() is not None:
                raise RuntimeError(f"Server failed to start. Exit code: {self.process.returncode}")
        except Exception as e:
            self.process = None
            raise RuntimeError(f"Error starting server: {str(e)}")

    def stop(self):
        if self.process:
            try:
                self.process.terminate()
                self.process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                self.process.kill()
            finally:
                self.process = None

    def switch(self, new_model_name: str, new_configs: Dict[str, Any]):
        self.stop()
        self.model_name = new_model_name
        self.configs = new_configs.copy()
        self.start(app.state.log_file)

    def is_running(self) -> bool:
        return self.process is not None and self.process.poll() is None

@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info(f"Running in {app.state.llama_cpp_or_vllm} mode...")
    yield
    logger.info("Stopping server...")
    global llama_server
    if llama_server and llama_server.is_running():
        llama_server.stop()

app = FastAPI(lifespan=lifespan)

@app.post("/wrapper_start")
def start_server(config: ServerConfig):
    global llama_server
    with server_lock:
        if llama_server and llama_server.is_running():
            raise HTTPException(status_code=400, detail="Server is already running")

        try:
            llama_server = LlamaServer(
                model_name=config.model_name,
                configs=config.configs,
                server_path=config.server_path
            )
            llama_server.start(app.state.log_file)
            return {
                "status": "success",
                "message": f"Server started with model: {config.model_name}",
                "model": config.model_name,
                "config": config.configs
            }
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

@app.post("/wrapper_switch")
def switch_model(request: SwitchRequest):
    global llama_server
    with server_lock:
        if not llama_server or not llama_server.is_running():
            raise HTTPException(status_code=400, detail="Server is not running")

        try:
            llama_server.switch(
                new_model_name=request.new_model_name,
                new_configs=request.new_configs
            )
            return {
                "status": "success",
                "message": f"Switched to model: {request.new_model_name}",
                "new_model": request.new_model_name,
                "new_config": request.new_configs
            }
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

@app.post("/wrapper_stop")
def stop_server():
    global llama_server
    with server_lock:
        if not llama_server or not llama_server.is_running():
            raise HTTPException(status_code=400, detail="Server is not running")

        try:
            llama_server.stop()
            return {"status": "success", "message": "Server stopped"}
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))

@app.get("/wrapper_status")
def get_status():
    global llama_server
    status = {
        "is_running": False,
        "current_model": None,
        "current_config": None
    }

    if llama_server and llama_server.is_running():
        status.update({
            "is_running": True,
            "current_model": llama_server.model_name,
            "current_config": llama_server.configs
        })

    return status

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server listening host")
    parser.add_argument("--port", type=int, default=8001, help="Server listening port")
    parser.add_argument("--log_file", type=str, default="./server.log", help="Log file path")
    parser.add_argument("--llama_cpp_or_vllm", "-lv", default="vllm",
                        choices=["vllm", "llama_cpp"],
                        help="Use `vllm` or `llama_cpp` as the backend")
    args = parser.parse_args()
    app.state.log_file = args.log_file
    app.state.llama_cpp_or_vllm = args.llama_cpp_or_vllm
    handler = StreamHandler()
    formatter = Formatter("%(levelname)s: %(message)s")
    handler.setFormatter(formatter)
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    uvicorn.run(app, host=args.host, port=args.port)

if __name__ == "__main__":
    main()
server_lock =
<unlocked _thread.lock object>
llama_server =
None
class
ServerConfig(pydantic.main.BaseModel):
class ServerConfig(BaseModel):
    model_name: str
    configs: Dict[str, Any]
    server_path: List[str] = ["./llama-server"]
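For illustration, a minimal sketch of a ServerConfig that would validate against this model; the model path and config keys below are hypothetical placeholders, not recommended values:

from Kkit.llm_utils.llamacpp_wrapper_server import ServerConfig

# All values here are placeholders for the example.
cfg = ServerConfig(
    model_name="./models/example-7b.Q4_K_M.gguf",
    configs={"port": 8080, "ctx_size": 4096},
)
print(cfg.model_dump())  # Pydantic v2 dict form, i.e. the JSON body for /wrapper_start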
class
SwitchRequest(pydantic.main.BaseModel):
class SwitchRequest(BaseModel):
    new_model_name: str
    new_configs: Dict[str, Any]
class
LlamaServer:
class LlamaServer:
    def __init__(self, model_name: str, configs: Dict[str, Any], server_path: List[str]):
        self.server_path = server_path
        self.model_name = model_name
        self.configs = configs.copy()
        self.process = None

    def _convert_configs_to_args(self) -> list:
        # Translate the configs dict into CLI arguments for the backend:
        # keys already starting with "-" pass through verbatim, others
        # become "--kebab-case" flags; True and "" yield bare flags.
        args = []
        for key, value in self.configs.items():
            if key.startswith("-"):
                arg_name = key
            else:
                arg_name = f"--{key.replace('_', '-')}"

            if isinstance(value, bool):
                if value:
                    args.append(arg_name)
            elif value == '':
                args.append(arg_name)
            else:
                args.extend([arg_name, str(value)])
        return args

    def start(self, log_file: str = "llama_server.log"):
        if self.process and self.process.poll() is None:
            raise RuntimeError("Server is already running")
        cmd = self.server_path.copy()
        if app.state.llama_cpp_or_vllm == "llama_cpp":
            cmd.extend(["--model", self.model_name])
        elif app.state.llama_cpp_or_vllm == "vllm":
            cmd.append(self.model_name)
        cmd.extend(self._convert_configs_to_args())
        print(cmd)

        try:
            with open(log_file, "w") as log_f:
                self.process = subprocess.Popen(
                    cmd,
                    stdout=log_f,
                    stderr=log_f,
                    text=True,
                )
            # Give the process a moment to crash on bad arguments; this is
            # not a readiness check.
            time.sleep(1)
            if self.process.poll() is not None:
                raise RuntimeError(f"Server failed to start. Exit code: {self.process.returncode}")
        except Exception as e:
            self.process = None
            raise RuntimeError(f"Error starting server: {str(e)}")

    def stop(self):
        if self.process:
            try:
                self.process.terminate()
                self.process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                self.process.kill()
            finally:
                self.process = None

    def switch(self, new_model_name: str, new_configs: Dict[str, Any]):
        self.stop()
        self.model_name = new_model_name
        self.configs = new_configs.copy()
        self.start(app.state.log_file)

    def is_running(self) -> bool:
        return self.process is not None and self.process.poll() is None
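A quick illustration of what _convert_configs_to_args produces; the flag names are placeholders chosen for the example, not a recommended backend configuration:

# Sketch: configs dict -> CLI argument list (flag names are hypothetical).
srv = LlamaServer(
    model_name="model.gguf",
    configs={"ctx_size": 4096, "flash_attn": True, "-ngl": 35},
    server_path=["./llama-server"],
)
print(srv._convert_configs_to_args())
# -> ['--ctx-size', '4096', '--flash-attn', '-ngl', '35']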
def
start(self, log_file: str = 'llama_server.log'):
    def start(self, log_file: str = "llama_server.log"):
        if self.process and self.process.poll() is None:
            raise RuntimeError("Server is already running")
        cmd = self.server_path.copy()
        if app.state.llama_cpp_or_vllm == "llama_cpp":
            cmd.extend(["--model", self.model_name])
        elif app.state.llama_cpp_or_vllm == "vllm":
            cmd.append(self.model_name)
        cmd.extend(self._convert_configs_to_args())
        print(cmd)

        try:
            with open(log_file, "w") as log_f:
                self.process = subprocess.Popen(
                    cmd,
                    stdout=log_f,
                    stderr=log_f,
                    text=True,
                )
            # Give the process a moment to crash on bad arguments; this is
            # not a readiness check.
            time.sleep(1)
            if self.process.poll() is not None:
                raise RuntimeError(f"Server failed to start. Exit code: {self.process.returncode}")
        except Exception as e:
            self.process = None
            raise RuntimeError(f"Error starting server: {str(e)}")
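Note that the time.sleep(1) above only catches processes that exit immediately; it does not confirm the backend is serving requests. A caller could poll the backend itself, as in this minimal sketch (it assumes the backend exposes a /health endpoint on the configured port, as recent llama-server builds do; adjust the URL for vLLM):

import time
import urllib.request

def wait_until_ready(url: str = "http://127.0.0.1:8080/health", timeout: float = 60.0) -> bool:
    # Poll until the backend answers HTTP or the timeout expires.
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=2):
                return True
        except OSError:
            time.sleep(1)  # connection refused or backend not ready yet
    return False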
@asynccontextmanager
async def
lifespan(app: fastapi.applications.FastAPI):
@asynccontextmanager
async def lifespan(app: FastAPI):
    logger.info(f"Running in {app.state.llama_cpp_or_vllm} mode...")
    yield
    logger.info("Stopping server...")
    global llama_server
    if llama_server and llama_server.is_running():
        llama_server.stop()
app =
<fastapi.applications.FastAPI object>
@app.post('/wrapper_start')
def
start_server(config: ServerConfig):
@app.post("/wrapper_start")
def start_server(config: ServerConfig):
    global llama_server
    with server_lock:
        if llama_server and llama_server.is_running():
            raise HTTPException(status_code=400, detail="Server is already running")

        try:
            llama_server = LlamaServer(
                model_name=config.model_name,
                configs=config.configs,
                server_path=config.server_path
            )
            llama_server.start(app.state.log_file)
            return {
                "status": "success",
                "message": f"Server started with model: {config.model_name}",
                "model": config.model_name,
                "config": config.configs
            }
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
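A client-side sketch of calling this endpoint; it assumes the third-party requests package is installed and the wrapper is on its default port from main(), and the payload values are placeholders:

import requests

resp = requests.post(
    "http://localhost:8001/wrapper_start",
    json={
        "model_name": "./models/example-7b.Q4_K_M.gguf",  # hypothetical path
        "configs": {"port": 8080},                        # hypothetical backend flag
    },
)
print(resp.json())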
@app.post('/wrapper_switch')
def
switch_model(request: SwitchRequest):
@app.post("/wrapper_switch")
def switch_model(request: SwitchRequest):
    global llama_server
    with server_lock:
        if not llama_server or not llama_server.is_running():
            raise HTTPException(status_code=400, detail="Server is not running")

        try:
            llama_server.switch(
                new_model_name=request.new_model_name,
                new_configs=request.new_configs
            )
            return {
                "status": "success",
                "message": f"Switched to model: {request.new_model_name}",
                "new_model": request.new_model_name,
                "new_config": request.new_configs
            }
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
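Switching follows the same pattern; again a sketch with placeholder values, assuming requests is installed:

import requests

resp = requests.post(
    "http://localhost:8001/wrapper_switch",
    json={"new_model_name": "./models/other-7b.gguf", "new_configs": {"port": 8080}},
)
print(resp.json())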
@app.post('/wrapper_stop')
def
stop_server():
@app.post("/wrapper_stop")
def stop_server():
    global llama_server
    with server_lock:
        if not llama_server or not llama_server.is_running():
            raise HTTPException(status_code=400, detail="Server is not running")

        try:
            llama_server.stop()
            return {"status": "success", "message": "Server stopped"}
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
@app.get('/wrapper_status')
def
get_status():
@app.get("/wrapper_status")
def get_status():
    global llama_server
    status = {
        "is_running": False,
        "current_model": None,
        "current_config": None
    }

    if llama_server and llama_server.is_running():
        status.update({
            "is_running": True,
            "current_model": llama_server.model_name,
            "current_config": llama_server.configs
        })

    return status
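A short sketch tying status and shutdown together (requests assumed installed; wrapper on its default port):

import requests

status = requests.get("http://localhost:8001/wrapper_status").json()
if status["is_running"]:
    print("stopping", status["current_model"])
    requests.post("http://localhost:8001/wrapper_stop")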
def
main():
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Server listening host")
    parser.add_argument("--port", type=int, default=8001, help="Server listening port")
    parser.add_argument("--log_file", type=str, default="./server.log", help="Log file path")
    parser.add_argument("--llama_cpp_or_vllm", "-lv", default="vllm",
                        choices=["vllm", "llama_cpp"],
                        help="Use `vllm` or `llama_cpp` as the backend")
    args = parser.parse_args()
    app.state.log_file = args.log_file
    app.state.llama_cpp_or_vllm = args.llama_cpp_or_vllm
    handler = StreamHandler()
    formatter = Formatter("%(levelname)s: %(message)s")
    handler.setFormatter(formatter)
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)
    uvicorn.run(app, host=args.host, port=args.port)
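The wrapper itself is launched from the command line, for example `python -m Kkit.llm_utils.llamacpp_wrapper_server --host 0.0.0.0 --port 8001 -lv llama_cpp`; all flags shown are defined in main() above.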