85 changes: 85 additions & 0 deletions tools/server/server-http.cpp
@@ -383,3 +383,88 @@ void server_http_context::post(const std::string & path, server_http_context::ha
});
}


//
// server_http_client
//

server_http_client::server_http_client(
const std::string & method,
const std::string & host,
int port,
const std::string & path,
const std::map<std::string, std::string> & headers,
const std::string & body,
const std::function<bool()> should_stop) {
// shared between reader and writer threads
auto cli = std::make_shared<httplib::Client>(host, port);
auto pipe = std::make_shared<pipe_t<msg_t>>();

// setup Client
cli->set_connection_timeout(0, 200000); // 200 milliseconds
this->status = 500; // to be overwritten upon response
this->cleanup = [pipe]() {
pipe->close_read();
pipe->close_write();
};

// wire up the receive end of the pipe
this->next = [pipe, should_stop](std::string & out) -> bool {
msg_t msg;
bool has_next = pipe->read(msg, should_stop);
if (!msg.data.empty()) {
out = std::move(msg.data);
}
return has_next;
};

// wire up the HTTP client
// note: do NOT capture `this` pointer, as it may be destroyed before the thread ends
httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) {
msg_t msg;
msg.status = response.status;
for (const auto & [key, value] : response.headers) {
msg.headers[key] = value;
}
pipe->write(std::move(msg)); // send headers first
return true;
};
httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) {
return pipe->write({{}, 0, std::string(data, data_length)}); // send data chunks
};

// prepare the request to destination server
httplib::Request req;
{
req.method = method;
req.path = path;
for (const auto & [key, value] : headers) {
req.set_header(key, value);
}
req.body = body;
req.response_handler = response_handler;
req.content_receiver = content_receiver;
}

// start the proxy thread
SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str());
this->thread = std::thread([cli, pipe, req]() {
auto result = cli->send(std::move(req));
if (result.error() != httplib::Error::Success) {
auto err_str = httplib::to_string(result.error());
SRV_ERR("http client error: %s\n", err_str.c_str());
pipe->write({{}, 500, ""}); // header
pipe->write({{}, 0, "proxy error: " + err_str}); // body
}
pipe->close_write(); // signal EOF to reader
SRV_DBG("%s", "client request thread ended\n");
});
this->thread.detach();

// wait for the first chunk (headers)
msg_t header;
pipe->read(header, should_stop);
SRV_DBG("%s", "received response headers\n");
this->status = header.status;
this->headers = header.headers;
}
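
As a usage note (not part of the diff), this is roughly how a caller might consume the proxied response; it relies only on what is set above: `status`, `headers`, and the `next` callback wired to the pipe. Host, port, and path are placeholders.

// minimal sketch, assuming "server-http.h" is included and an upstream
// instance is listening on 127.0.0.1:8081 (placeholder values)
static void dump_proxied_response() {
    server_http_client cli(
        "GET", "127.0.0.1", 8081, "/props",
        /*headers*/ {}, /*body*/ "",
        /*should_stop*/ []() { return false; });

    // the constructor blocks until the first pipe message (response headers) arrives
    printf("upstream status: %d\n", cli.status);

    // stream the body chunk by chunk until the writer closes the pipe (EOF)
    std::string chunk;
    while (cli.next(chunk)) {
        fwrite(chunk.data(), 1, chunk.size(), stdout);
        chunk.clear();
    }
}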
76 changes: 76 additions & 0 deletions tools/server/server-http.h
@@ -75,3 +75,79 @@ struct server_http_context {
// for debugging
std::string listening_address;
};



#include <queue>
#include <mutex>
#include <condition_variable>
#include <atomic>
#include <thread>
#include <chrono>

struct server_http_client : server_http_res {
std::function<void()> cleanup = nullptr;
public:
server_http_client(const std::string & method,
const std::string & host,
int port,
const std::string & path,
const std::map<std::string, std::string> & headers,
const std::string & body,
const std::function<bool()> should_stop);
~server_http_client() {
if (cleanup) {
cleanup();
}
}
private:
std::thread thread;
struct msg_t {
std::map<std::string, std::string> headers;
int status = 0;
std::string data;
};
// simple implementation of a pipe
template<typename T>
struct pipe_t {
std::mutex mutex;
std::condition_variable cv;
std::queue<T> queue;
std::atomic<bool> writer_closed{false};
std::atomic<bool> reader_closed{false};
void close_write() {
writer_closed.store(true);
cv.notify_all();
}
void close_read() {
reader_closed.store(true);
cv.notify_all();
}
bool read(T & output, const std::function<bool()> & should_stop) {
std::unique_lock<std::mutex> lk(mutex);
constexpr auto poll_interval = std::chrono::milliseconds(500);
while (true) {
if (!queue.empty()) {
output = std::move(queue.front());
queue.pop();
return true;
}
if (writer_closed.load()) {
return false; // clean EOF
}
if (should_stop()) {
close_read(); // signal broken pipe to writer
return false; // cancelled / reader no longer alive
}
cv.wait_for(lk, poll_interval);
}
}
bool write(T && data) {
std::lock_guard<std::mutex> lk(mutex);
if (reader_closed.load()) {
return false; // broken pipe
}
queue.push(std::move(data));
cv.notify_one();
return true;
}
};
};
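
The nested pipe_t is a single-producer/single-consumer channel with cooperative cancellation: write() fails once the reader has closed its end, and read() returns false either on clean EOF (writer closed) or when should_stop fires. A standalone sketch of that interplay, illustrative only since pipe_t is a private nested type in the PR:

// illustrative only: assumes an equivalent pipe_t<std::string> is visible at
// namespace scope (in the PR it is private to server_http_client)
#include <cstdio>
#include <string>
#include <thread>

int main() {
    pipe_t<std::string> pipe;

    std::thread writer([&pipe]() {
        for (int i = 0; i < 3; i++) {
            if (!pipe.write("chunk " + std::to_string(i))) {
                return; // reader closed its end: broken pipe, stop producing
            }
        }
        pipe.close_write(); // clean EOF for the reader
    });

    std::string chunk;
    auto never_stop = []() { return false; };
    while (pipe.read(chunk, never_stop)) {
        printf("%s\n", chunk.c_str()); // three chunks, then read() returns false
    }
    writer.join();
    return 0;
}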
171 changes: 170 additions & 1 deletion tools/server/server.cpp
@@ -5108,6 +5108,109 @@ struct server_routes {
return res;
};

//
// router server
//
server_instances instances;
server_http_context::handler_t proxy_get = [this](const server_http_req & req) {
std::string method = "GET";
std::string model = req.get_param("model");
if (req.path == "/props" && model.empty()) {
return handle_default_props(req);
}
instances.ensure_model_loaded(model);
return handle_proxy(req, method, model);
};
server_http_context::handler_t proxy_post = [this](const server_http_req & req) {
std::string method = "POST";
json body = json::parse(req.body);
std::string model = json_value(body, "model", std::string());
instances.ensure_model_loaded(model);
return handle_proxy(req, method, model);
};
server_http_res_ptr handle_proxy(const server_http_req & req, std::string & method, std::string model) {
auto meta = instances.get_meta(model);
if (!meta.has_value()) {
auto res = std::make_unique<server_res_generator>(ctx_server);
res->error(format_error_response("model is unavailable", ERROR_TYPE_UNAVAILABLE));
return res;
}
server_http_res_ptr res(new server_http_client(
method, params.hostname, meta->port,
req.path, req.headers, req.body, req.should_stop
));
return res;
}
server_http_res_ptr handle_default_props(const server_http_req &) {
auto res = std::make_unique<server_res_generator>(ctx_server);
// this is a dummy response to make sure webui doesn't break
res->ok({
{"model_alias", "llama-server"},
{"model_path", "none"},
{"default_generation_settings", {
{"params", json{}},
{"n_ctx", 0},
}},
});
return res;
}
server_http_context::handler_t post_router_models_load = [this](const server_http_req & req) {
auto res = std::make_unique<server_res_generator>(ctx_server);
json body = json::parse(req.body);
std::string model = json_value(body, "model", std::string());
instances.create(model);
res->ok({{"success", true}});
return res;
};
server_http_context::handler_t post_router_models_status = [this](const server_http_req & req) {
auto res = std::make_unique<server_res_generator>(ctx_server);
json body = json::parse(req.body);
std::string model = json_value(body, "model", std::string());
std::string value = json_value(body, "value", std::string());
if (!instances.get_meta(model).has_value()) {
auto res = std::make_unique<server_res_generator>(ctx_server);
res->error(format_error_response("model is unavailable", ERROR_TYPE_UNAVAILABLE));
return res;
}
instances.update_status(model, value);
res->ok({{"success", true}});
return res;
};
server_http_context::handler_t get_router_models = [this](const server_http_req &) {
auto res = std::make_unique<server_res_generator>(ctx_server);
json models_json = json::array();
auto models = common_list_cached_models();
for (const auto & model : models) {
auto model_name = model.to_string();
auto meta = instances.get_meta(model_name);
bool found = meta.has_value();
models_json.push_back(json {
{"model", model_name},
{"name", model_name},
{"id", model_name},
// TODO: other fields...
{"status", {
{"value", found ? meta->status : "unloaded"}
}},
});
}
res->ok({{"data", models_json}});
return res;
};
server_http_context::handler_t post_router_models_unload = [this](const server_http_req & req) {
auto res = std::make_unique<server_res_generator>(ctx_server);
json body = json::parse(req.body);
std::string model = json_value(body, "model", std::string());
if (!instances.get_meta(model).has_value()) {
auto res = std::make_unique<server_res_generator>(ctx_server);
res->error(format_error_response("model is unavailable", ERROR_TYPE_UNAVAILABLE));
return res;
}
instances.kill_single(model);
res->ok({{"success", true}});
return res;
};

private:
std::unique_ptr<server_res_generator> handle_completions_impl(
server_task_type type,
@@ -5501,7 +5604,7 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t
};
}

int main(int argc, char ** argv) {
int main(int argc, char ** argv, char ** envp) {
// own arguments required by this example
common_params params;

@@ -5549,6 +5652,38 @@ int main(int argc, char ** argv) {
// register API routes
server_routes routes(params, ctx_server, ctx_http);

// hacky, replace handlers with proxy handlers if this is a router server
bool is_router_server = params.model.path == DEFAULT_MODEL_PATH;
if (is_router_server) {
// setup server instances manager
routes.instances.envp = envp;
routes.instances.router_port = params.port;

// proxy handlers
routes.get_props = routes.proxy_get;
routes.post_props = routes.proxy_post;
routes.post_completions = routes.proxy_post;
routes.post_completions_oai = routes.proxy_post;
routes.post_chat_completions = routes.proxy_post;
routes.post_infill = routes.proxy_post;
routes.post_embeddings = routes.proxy_post;
routes.post_embeddings_oai = routes.proxy_post;
routes.post_rerank = routes.proxy_post;
routes.post_tokenize = routes.proxy_post;
routes.post_detokenize = routes.proxy_post;
routes.post_apply_template = routes.proxy_post;
routes.get_lora_adapters = routes.proxy_get;
routes.post_lora_adapters = routes.proxy_post;
routes.get_slots = routes.proxy_get;
routes.post_slots = routes.proxy_post;

// custom routes for router
routes.get_models = routes.get_router_models;
ctx_http.post("/models/load", ex_wrapper(routes.post_router_models_load));
ctx_http.post("/models/unload", ex_wrapper(routes.post_router_models_unload));
ctx_http.post("/models/status", ex_wrapper(routes.post_router_models_status));
}

ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics));
@@ -5594,6 +5729,8 @@ int main(int argc, char ** argv) {
llama_backend_free();
};

if (!is_router_server) { // HACKY

// start the HTTP server before loading the model to be able to serve /health requests
if (!ctx_http.start()) {
clean_up();
@@ -5631,6 +5768,8 @@ int main(int argc, char ** argv) {
ctx_server.queue_tasks.terminate();
};

} // end of !is_router_server

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action;
sigint_action.sa_handler = signal_handler;
@@ -5645,6 +5784,23 @@
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif

if (!is_router_server) { // HACKY

// notify to main router if needed
char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT");
if (router_port != nullptr) {
SRV_INF("%s: notifying to main router on port %s\n", __func__, router_port);
server_http_client notify_router(
"POST", params.hostname, std::atoi(router_port),
"/models/status",
{ {"Content-Type", "application/json"} },
json {{ "model", params.model_alias }, { "value", "loaded" }}.dump(),
[]() { return false; }
);
std::string dummy;
notify_router.next(dummy); // ignore the response
}

LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
LOG_INF("%s: starting the main loop...\n", __func__);
// this call blocks the main thread until queue_tasks.terminate() is called
@@ -5655,6 +5811,19 @@
ctx_http.thread.join();
}
llama_memory_breakdown_print(ctx_server.ctx);
} else {
shutdown_handler = [&](int) {
ctx_http.stop();
};
if (!ctx_http.start()) {
LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
return 1;
}
ctx_http.is_ready.store(true);
ctx_http.thread.join(); // keep the main thread alive
// kill_all_instances(routes.map_model_to_port); // why does this also kill the main instance?
LOG_INF("%s: server stopped\n", __func__);
} // end of !is_router_server

return 0;
}
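
For reference, a hedged end-to-end sketch of driving the new router endpoints with the server_http_client added in this PR. The host, port, model name, and the /v1/models path for the models listing are placeholders/assumptions, not values taken from the diff.

// illustrative only: exercise the router endpoints added above
#include <cstdio>
#include <string>
#include "server-http.h"       // server_http_client (added in this PR)
#include <nlohmann/json.hpp>   // already used by server.cpp

using json = nlohmann::ordered_json;

// drain the streamed body of a server_http_client into one string
static std::string drain(server_http_client & cli) {
    std::string out, chunk;
    while (cli.next(chunk)) {
        out += chunk;
        chunk.clear();
    }
    return out;
}

int main() {
    auto never_stop = []() { return false; };

    // ask the router to spawn an instance for a model (placeholder name)
    server_http_client load(
        "POST", "127.0.0.1", 8080, "/models/load",
        { {"Content-Type", "application/json"} },
        json{{"model", "some-model"}}.dump(),
        never_stop);
    printf("/models/load -> %d %s\n", load.status, drain(load).c_str());

    // list models and their status (served by get_router_models); assumes the
    // listing is still reachable at /v1/models as in the existing server
    server_http_client list("GET", "127.0.0.1", 8080, "/v1/models", {}, "", never_stop);
    printf("/v1/models -> %d %s\n", list.status, drain(list).c_str());
    return 0;
}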