@@ -4319,16 +4319,15 @@ struct server_context {
43194319 }
43204320};
43214321
4322- // generator-like API for server responses
4323- struct server_response_generator {
4322+ // generator-like API for server responses; supports polling the connection state and aggregating results
4323+ struct server_response_reader {
43244324 std::unordered_set<int > id_tasks;
43254325 server_context & ctx_server;
43264326 size_t received_count = 0 ;
43274327 bool cancelled = false ;
43284328
4329- server_response_generator (server_context & ctx_server) : ctx_server(ctx_server) {}
4330- ~server_response_generator () {
4331- SRV_DBG (" %s" , " deleting server_response_generator\n " );
4329+ server_response_reader (server_context & ctx_server) : ctx_server(ctx_server) {}
4330+ ~server_response_reader () {
43324331 stop ();
43334332 }
43344333
@@ -5000,9 +4999,9 @@ int main(int argc, char ** argv) {
50004999 GGML_ASSERT (type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
50015000
50025001 auto completion_id = gen_chatcmplid ();
5003- // need to store the generator as a pointer, so that it won't be destroyed when the handle returns
5002+ // need to store the reader as a pointer, so that it won't be destroyed when the handler returns
50045003 // use shared_ptr as it's shared between the chunked_content_provider() and on_complete()
5005- const auto gen = std::make_shared<server_response_generator >(ctx_server);
5004+ const auto rd = std::make_shared<server_response_reader >(ctx_server);
50065005
50075006 try {
50085007 std::vector<server_task> tasks;
@@ -5043,7 +5042,7 @@ int main(int argc, char ** argv) {
50435042 tasks.push_back (std::move (task));
50445043 }
50455044
5046- gen ->post_tasks (std::move (tasks));
5045+ rd ->post_tasks (std::move (tasks));
50475046 } catch (const std::exception & e) {
50485047 res_error (res, format_error_response (e.what (), ERROR_TYPE_INVALID_REQUEST));
50495048 return ;
@@ -5053,7 +5052,7 @@ int main(int argc, char ** argv) {
50535052
50545053 if (!stream) {
50555054 // non-stream, wait for the results
5056- auto all_results = gen ->wait_for_all (is_connection_closed);
5055+ auto all_results = rd ->wait_for_all (is_connection_closed);
50575056 if (all_results.is_terminated ) {
50585057 return ; // connection is closed
50595058 } else if (all_results.error ) {
@@ -5073,7 +5072,7 @@ int main(int argc, char ** argv) {
50735072 // in streaming mode, the first error must be treated as non-stream response
50745073 // this is to match the OAI API behavior
50755074 // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
5076- server_task_result_ptr first_result = gen ->next (is_connection_closed);
5075+ server_task_result_ptr first_result = rd ->next (is_connection_closed);
50775076 if (first_result == nullptr ) {
50785077 return ; // connection is closed
50795078 } else if (first_result->is_error ()) {
@@ -5088,7 +5087,7 @@ int main(int argc, char ** argv) {
50885087
50895088 // next responses are streamed
50905089 json first_result_json = first_result->to_json ();
5091- const auto chunked_content_provider = [first_result_json, gen , oaicompat](size_t , httplib::DataSink & sink) mutable -> bool {
5090+ const auto chunked_content_provider = [first_result_json, rd , oaicompat](size_t , httplib::DataSink & sink) mutable -> bool {
50925091 // flush the first result as it's not an error
50935092 if (!first_result_json.empty ()) {
50945093 if (!server_sent_event (sink, first_result_json)) {
@@ -5099,7 +5098,7 @@ int main(int argc, char ** argv) {
50995098 }
51005099
51015100 // receive subsequent results
5102- auto result = gen ->next ([&sink]{ return !sink.is_writable (); });
5101+ auto result = rd ->next ([&sink]{ return !sink.is_writable (); });
51035102 if (result == nullptr ) {
51045103 sink.done ();
51055104 return false ; // connection is closed, go to on_complete()
@@ -5126,7 +5125,7 @@ int main(int argc, char ** argv) {
51265125 }
51275126
51285127 // check if there is more data
5129- if (!gen ->has_next ()) {
5128+ if (!rd ->has_next ()) {
51305129 if (oaicompat != OAICOMPAT_TYPE_NONE) {
51315130 static const std::string ev_done = " data: [DONE]\n\n " ;
51325131 sink.write (ev_done.data (), ev_done.size ());
@@ -5139,8 +5138,8 @@ int main(int argc, char ** argv) {
51395138 return true ;
51405139 };
51415140
5142- auto on_complete = [gen ](bool ) {
5143- gen ->stop ();
5141+ auto on_complete = [rd ](bool ) {
5142+ rd ->stop ();
51445143 };
51455144
51465145 res.set_chunked_content_provider (" text/event-stream" , chunked_content_provider, on_complete);
@@ -5434,7 +5433,7 @@ int main(int argc, char ** argv) {
54345433
54355434 // create and queue the task
54365435 json responses = json::array ();
5437- server_response_generator gen (ctx_server);
5436+ server_response_reader rd (ctx_server);
54385437 {
54395438 std::vector<server_task> tasks;
54405439 for (size_t i = 0 ; i < tokenized_prompts.size (); i++) {
@@ -5450,11 +5449,11 @@ int main(int argc, char ** argv) {
54505449
54515450 tasks.push_back (std::move (task));
54525451 }
5453- gen .post_tasks (std::move (tasks));
5452+ rd .post_tasks (std::move (tasks));
54545453 }
54555454
54565455 // wait for the results
5457- auto all_results = gen .wait_for_all (req.is_connection_closed );
5456+ auto all_results = rd .wait_for_all (req.is_connection_closed );
54585457
54595458 // collect results
54605459 if (all_results.is_terminated ) {
@@ -5520,7 +5519,7 @@ int main(int argc, char ** argv) {
55205519
55215520 // create and queue the task
55225521 json responses = json::array ();
5523- server_response_generator gen (ctx_server);
5522+ server_response_reader rd (ctx_server);
55245523 {
55255524 std::vector<server_task> tasks;
55265525 tasks.reserve (documents.size ());
@@ -5532,11 +5531,11 @@ int main(int argc, char ** argv) {
55325531 task.tokens = std::move (tmp);
55335532 tasks.push_back (std::move (task));
55345533 }
5535- gen .post_tasks (std::move (tasks));
5534+ rd .post_tasks (std::move (tasks));
55365535 }
55375536
55385537 // wait for the results
5539- auto all_results = gen .wait_for_all (req.is_connection_closed );
5538+ auto all_results = rd .wait_for_all (req.is_connection_closed );
55405539
55415540 // collect results
55425541 if (all_results.is_terminated ) {
0 commit comments