Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 30 additions & 31 deletions tools/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4432,6 +4432,17 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
SRV_DBG("response: %s\n", res.body.c_str());
}

static void res_error(httplib::Response & res, const json & error_data) {
json final_response {{"error", error_data}};
res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
res.status = json_value(error_data, "code", 500);
}

static void res_ok(httplib::Response & res, const json & data) {
res.set_content(safe_json_to_str(data), MIMETYPE_JSON);
res.status = 200;
}

std::function<void(int)> shutdown_handler;
std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;

Expand Down Expand Up @@ -4501,19 +4512,7 @@ int main(int argc, char ** argv) {

svr->set_default_headers({{"Server", "llama.cpp"}});
svr->set_logger(log_server_request);

auto res_error = [](httplib::Response & res, const json & error_data) {
json final_response {{"error", error_data}};
res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
res.status = json_value(error_data, "code", 500);
};

auto res_ok = [](httplib::Response & res, const json & data) {
res.set_content(safe_json_to_str(data), MIMETYPE_JSON);
res.status = 200;
};

svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
std::string message;
try {
std::rethrow_exception(ep);
Expand All @@ -4532,7 +4531,7 @@ int main(int argc, char ** argv) {
}
});

svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
svr->set_error_handler([](const httplib::Request &, httplib::Response & res) {
if (res.status == 404) {
res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
}
Expand Down Expand Up @@ -4562,7 +4561,7 @@ int main(int argc, char ** argv) {
// Middlewares
//

auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
auto middleware_validate_api_key = [&params](const httplib::Request & req, httplib::Response & res) {
static const std::unordered_set<std::string> public_endpoints = {
"/health",
"/v1/health",
Expand Down Expand Up @@ -4600,7 +4599,7 @@ int main(int argc, char ** argv) {
return false;
};

auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
auto middleware_server_state = [&state](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
if (current_state == SERVER_STATE_LOADING_MODEL) {
auto tmp = string_split<std::string>(req.path, '.');
Expand Down Expand Up @@ -4788,7 +4787,7 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK
};

const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
const auto handle_slots_save = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
json request_data = json::parse(req.body);
std::string filename = request_data.at("filename");
if (!fs_validate_filename(filename)) {
Expand Down Expand Up @@ -4820,7 +4819,7 @@ int main(int argc, char ** argv) {
res_ok(res, result->to_json());
};

const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
const auto handle_slots_restore = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
json request_data = json::parse(req.body);
std::string filename = request_data.at("filename");
if (!fs_validate_filename(filename)) {
Expand Down Expand Up @@ -4853,7 +4852,7 @@ int main(int argc, char ** argv) {
res_ok(res, result->to_json());
};

const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
int task_id = ctx_server.queue_tasks.get_new_id();
{
server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
Expand All @@ -4876,7 +4875,7 @@ int main(int argc, char ** argv) {
res_ok(res, result->to_json());
};

const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
const auto handle_slots_action = [&params, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
if (params.slot_save_path.empty()) {
res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
return;
Expand Down Expand Up @@ -4905,7 +4904,7 @@ int main(int argc, char ** argv) {
}
};

const auto handle_props = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
const auto handle_props = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
json default_generation_settings_for_props;

{
Expand Down Expand Up @@ -4947,7 +4946,7 @@ int main(int argc, char ** argv) {
res_ok(res, data);
};

const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
const auto handle_props_change = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params_base.endpoint_props) {
res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
return;
Expand All @@ -4960,7 +4959,7 @@ int main(int argc, char ** argv) {
res_ok(res, {{ "success", true }});
};

const auto handle_api_show = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
const auto handle_api_show = [&ctx_server](const httplib::Request &, httplib::Response & res) {
bool has_mtmd = ctx_server.mctx != nullptr;
json data = {
{
Expand Down Expand Up @@ -4991,7 +4990,7 @@ int main(int argc, char ** argv) {

// handle completion-like requests (completion, chat, infill)
// we can optionally provide a custom format for partial results and final results
const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
const auto handle_completions_impl = [&ctx_server](
server_task_type type,
json & data,
const std::vector<raw_buffer> & files,
Expand Down Expand Up @@ -5139,7 +5138,7 @@ int main(int argc, char ** argv) {
OAICOMPAT_TYPE_COMPLETION);
};

const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
const auto handle_infill = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
// check model compatibility
std::string err;
if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
Expand Down Expand Up @@ -5238,7 +5237,7 @@ int main(int argc, char ** argv) {
};

// same with handle_chat_completions, but without inference part
const auto handle_apply_template = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
const auto handle_apply_template = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
auto body = json::parse(req.body);
std::vector<raw_buffer> files; // dummy, unused
json data = oaicompat_chat_params_parse(
Expand All @@ -5248,7 +5247,7 @@ int main(int argc, char ** argv) {
res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
};

const auto handle_models = [&params, &ctx_server, &state, &res_ok](const httplib::Request &, httplib::Response & res) {
const auto handle_models = [&params, &ctx_server, &state](const httplib::Request &, httplib::Response & res) {
server_state current_state = state.load();
json model_meta = nullptr;
if (current_state == SERVER_STATE_READY) {
Expand Down Expand Up @@ -5293,7 +5292,7 @@ int main(int argc, char ** argv) {
res_ok(res, models);
};

const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body);

json tokens_response = json::array();
Expand Down Expand Up @@ -5334,7 +5333,7 @@ int main(int argc, char ** argv) {
res_ok(res, data);
};

const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
const json body = json::parse(req.body);

std::string content;
Expand All @@ -5347,7 +5346,7 @@ int main(int argc, char ** argv) {
res_ok(res, data);
};

const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
const auto handle_embeddings_impl = [&ctx_server](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
if (!ctx_server.params_base.embedding) {
res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return;
Expand Down Expand Up @@ -5457,7 +5456,7 @@ int main(int argc, char ** argv) {
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
};

const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
return;
Expand Down
Loading