@@ -168,7 +168,7 @@ impl Backend {
168168 }
169169 }
170170 for shape in shapes. iter ( ) {
171- let batch = self . create_warmup_batch ( * shape, max_token as u32 ) ;
171+ let batch = self . create_warmup_batch ( * shape, max_token as u32 , seq_bucket_size as u32 ) ;
172172 match & self . model_type {
173173 ModelType :: Classifier => self . predict ( batch) . await . map ( |_| ( ) ) ,
174174 ModelType :: Embedding ( _) => self . embed ( batch) . await . map ( |_| ( ) ) ,
@@ -179,19 +179,25 @@ impl Backend {
179179 }
180180
181181 #[ instrument( skip_all) ]
182- pub fn create_warmup_batch ( & self , shape : ( u32 , u32 ) , max_token : u32 ) -> Batch {
182+ pub fn create_warmup_batch ( & self , shape : ( u32 , u32 ) , max_token : u32 , seq_bucket_size : u32 ) -> Batch {
183183 let ( batch_size, length) = shape;
184+ let min_length = length. saturating_sub ( seq_bucket_size) . saturating_add ( 1 ) ;
185+ let tmp_length = if min_length < length {
186+ rand:: rng ( ) . random_range ( min_length..length)
187+ } else {
188+ length
189+ } ;
184190 let mut batched_input_ids = Vec :: new ( ) ;
185191 let mut batched_token_type_ids = Vec :: new ( ) ;
186192 let mut batched_position_ids = Vec :: new ( ) ;
187193 let mut cumulative_seq_lengths = Vec :: with_capacity ( batch_size as usize + 1 ) ;
188194 let mut pooled_indices = Vec :: with_capacity ( batch_size as usize ) ;
189195 cumulative_seq_lengths. push ( 0 ) ;
190- let input_ids: Vec < u32 > = ( 0 ..length )
196+ let input_ids: Vec < u32 > = ( 0 ..tmp_length )
191197 . map ( |_| rand:: rng ( ) . random_range ( 0 ..max_token) )
192198 . collect ( ) ;
193- let token_type_ids: Vec < u32 > = vec ! [ 0 ; length as usize ] ;
194- let position_ids: Vec < u32 > = ( 0 ..length ) . collect ( ) ;
199+ let token_type_ids: Vec < u32 > = vec ! [ 0 ; tmp_length as usize ] ;
200+ let position_ids: Vec < u32 > = ( 0 ..tmp_length ) . collect ( ) ;
195201 let mut current_length = 0 ;
196202 for batch_id in 0 ..batch_size {
197203 batched_input_ids. extend ( input_ids. iter ( ) . cloned ( ) ) ;
@@ -206,7 +212,7 @@ impl Backend {
206212 token_type_ids : batched_token_type_ids,
207213 position_ids : batched_position_ids,
208214 cumulative_seq_lengths,
209- max_length : length ,
215+ max_length : tmp_length ,
210216 pooled_indices,
211217 raw_indices : vec ! [ ] ,
212218 }
0 commit comments