@@ -93,40 +93,25 @@ def wait_for_termination(self) -> None:
       self.stop()


-def run(
-    port: int,
+def create_driver(
     config: Type[config_lib.ServerConfig],
     devices: Any,
-    credentials: Any = grpc.insecure_server_credentials(),
-    threads: int | None = None,
     jax_padding: bool = True,
-    metrics_server_config: config_lib.MetricsServerConfig | None = None,
-    enable_jax_profiler: bool = False,
-    jax_profiler_port: int = 9999,
+    metrics_collector: JetstreamMetricsCollector | None = None,
     enable_model_warmup: bool = False,
-) -> JetStreamServer:
-  """Runs a server with a specified config.
+):
+  """Creates a driver with a specified config.

   Args:
-    port: Port on which the server will be made available.
     config: A ServerConfig to config engine, model, device slices, etc.
     devices: Device objects, will be used to get engine with proper slicing.
-    credentials: Should use grpc credentials by default.
-    threads: Number of RPC handlers worker threads. This should be at least
-      equal to the decoding batch size to fully saturate the decoding queue.
     jax_padding: The flag to enable JAX padding during tokenization.
-    metrics_server_config: The config to enable the Prometheus metric server.
-    enable_jax_profiler: The flag to enable the JAX profiler server.
-    jax_profiler_port: The port for the JAX profiler server (default: 9999).
+    metrics_collector: The JetStream Prometheus metrics collector.
     enable_model_warmup: The flag to enable model server warmup with AOT.

   Returns:
-    JetStreamServer that wraps the grpc server and orchestrator driver.
+    An orchestrator driver.
   """
-
-  server_start_time = time.time()
-
-  logging.info("Kicking off gRPC server.")
   engines = config_lib.get_engines(config, devices=devices)
   prefill_params = [pe.load_params() for pe in engines.prefill_engines]
   generate_params = [ge.load_params() for ge in engines.generate_engines]
@@ -136,19 +121,6 @@ def run(
       len(config.prefill_slices) + len(config.generate_slices) == 0
   )

-  # Setup Prometheus server
-  metrics_collector: JetstreamMetricsCollector = None
-  if metrics_server_config and metrics_server_config.port:
-    logging.info(
-        "Starting Prometheus server on port %d", metrics_server_config.port
-    )
-    start_http_server(metrics_server_config.port)
-    metrics_collector = JetstreamMetricsCollector()
-  else:
-    logging.info(
-        "Not starting Prometheus server: --prometheus_port flag not set"
-    )
-
   prefill_engines = engines.prefill_engines + engines.interleaved_engines
   generate_engines = engines.generate_engines + engines.interleaved_engines
   prefill_params = prefill_params + shared_params
@@ -182,7 +154,7 @@ def run(
       traceback.print_exc()
       os.kill(os.getpid(), signal.SIGKILL)

-  driver = orchestrator.Driver(
+  return orchestrator.Driver(
       prefill_engines=prefill_engines,
       generate_engines=generate_engines,
       prefill_params=prefill_params,
@@ -192,6 +164,56 @@ def run(
       metrics_collector=metrics_collector,
       is_ray_backend=config.is_ray_backend,
   )
+
+
+def run(
+    port: int,
+    config: Type[config_lib.ServerConfig],
+    devices: Any,
+    credentials: Any = grpc.insecure_server_credentials(),
+    threads: int | None = None,
+    jax_padding: bool = True,
+    metrics_server_config: config_lib.MetricsServerConfig | None = None,
+    enable_jax_profiler: bool = False,
+    jax_profiler_port: int = 9999,
+    enable_model_warmup: bool = False,
+) -> JetStreamServer:
+  """Runs a server with a specified config.
+
+  Args:
+    port: Port on which the server will be made available.
+    config: A ServerConfig to config engine, model, device slices, etc.
+    devices: Device objects, will be used to get engine with proper slicing.
+    credentials: Should use grpc credentials by default.
+    threads: Number of RPC handlers worker threads. This should be at least
+      equal to the decoding batch size to fully saturate the decoding queue.
+    jax_padding: The flag to enable JAX padding during tokenization.
+    metrics_server_config: The config to enable the Prometheus metric server.
+    enable_jax_profiler: The flag to enable the JAX profiler server.
+    jax_profiler_port: The port for the JAX profiler server (default: 9999).
+    enable_model_warmup: The flag to enable model server warmup with AOT.
+
+  Returns:
+    JetStreamServer that wraps the grpc server and orchestrator driver.
+  """
+  server_start_time = time.time()
+  logging.info("Kicking off gRPC server.")
+  # Setup Prometheus server
+  metrics_collector: JetstreamMetricsCollector = None
+  if metrics_server_config and metrics_server_config.port:
+    logging.info(
+        "Starting Prometheus server on port %d", metrics_server_config.port
+    )
+    start_http_server(metrics_server_config.port)
+    metrics_collector = JetstreamMetricsCollector()
+  else:
+    logging.info(
+        "Not starting Prometheus server: --prometheus_port flag not set"
+    )
+
+  driver = create_driver(
+      config, devices, jax_padding, metrics_collector, enable_model_warmup
+  )
   # We default threads to the total number of concurrent allowed decodes,
   # to make sure we can fully saturate the model. Set default minimum to 64.
   threads = threads or max(driver.get_total_concurrent_requests(), 64)
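
For reference, a minimal usage sketch of the two entry points after this split. It is not taken from this commit: the import paths, `my_config`, and the device list are assumptions/placeholders.

```python
# Minimal sketch only; import paths and config/device values are assumed.
import jax

from jetstream.core import config_lib, server_lib

devices = jax.devices()  # device objects handed through to the engines
my_config = config_lib.ServerConfig  # placeholder: pass your ServerConfig class here

# New with this change: build just the orchestrator driver, without starting
# a gRPC server (useful for tests or offline runs). Pass a
# JetstreamMetricsCollector to record Prometheus metrics, or None to skip.
driver = server_lib.create_driver(
    config=my_config,
    devices=devices,
    jax_padding=True,
    metrics_collector=None,
    enable_model_warmup=False,
)
print(driver.get_total_concurrent_requests())

# Unchanged for existing callers: run() now builds its driver via
# create_driver() and still returns a JetStreamServer wrapping the gRPC server.
server = server_lib.run(port=9000, config=my_config, devices=devices)
server.wait_for_termination()
```

Splitting driver construction out of run() keeps Prometheus and gRPC setup in run() while making the driver itself reusable on its own.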