@@ -10,24 +10,24 @@ class BaseResponse(BaseModel):
1010
1111class GenerateRequest (BaseModel ):
1212 text : List [str ] = None
13- min_length : int = None
14- do_sample : bool = None
15- early_stopping : bool = None
16- temperature : float = None
17- top_k : int = None
18- top_p : float = None
19- typical_p : float = None
20- repetition_penalty : float = None
13+ min_length : int = 0
14+ do_sample : bool = False
15+ early_stopping : bool = False
16+ temperature : float = 1
17+ top_k : int = 50
18+ top_p : float = 1
19+ typical_p : float = 1
20+ repetition_penalty : float = 1
2121 bos_token_id : int = None
2222 pad_token_id : int = None
2323 eos_token_id : int = None
24- length_penalty : float = None
25- no_repeat_ngram_size : int = None
26- encoder_no_repeat_ngram_size : int = None
24+ length_penalty : float = 1
25+ no_repeat_ngram_size : int = 0
26+ encoder_no_repeat_ngram_size : int = 0
2727 max_time : float = None
2828 max_new_tokens : int = None
2929 decoder_start_token_id : int = None
30- diversity_penalty : float = None
30+ diversity_penalty : float = 0
3131 forced_bos_token_id : int = None
3232 forced_eos_token_id : int = None
3333 exponential_decay_length_penalty : float = None
@@ -89,32 +89,51 @@ def parse_field(kwargs: dict, field: str, dtype: type, default_value: Any = None
8989
def create_generate_request(text: List[str], generate_kwargs: dict) -> GenerateRequest:
    """Build a GenerateRequest from raw user-supplied generate kwargs.

    Each recognized key in ``generate_kwargs`` is parsed/coerced to its
    expected type via ``parse_field``; any key that is absent falls back to
    the corresponding default declared on the GenerateRequest model (obtained
    from a default-constructed instance, so the defaults live in one place).

    Args:
        text: batch of input prompts to generate from.
        generate_kwargs: untyped user dict of generation parameters.

    Returns:
        A fully-populated GenerateRequest.
    """
    # Default-construct once so every fallback below reads the model's own
    # declared default instead of duplicating literals here.
    default_request = GenerateRequest()

    # NOTE(review): the pre-patch version also parsed "num_beams" and
    # "num_beam_group" from generate_kwargs; this version silently ignores
    # them — confirm that dropping beam-search parameters was intentional.
    return GenerateRequest(
        text=text,
        min_length=parse_field(generate_kwargs, "min_length", int, default_request.min_length),
        do_sample=parse_field(generate_kwargs, "do_sample", bool, default_request.do_sample),
        early_stopping=parse_field(generate_kwargs, "early_stopping", bool, default_request.early_stopping),
        temperature=parse_field(generate_kwargs, "temperature", float, default_request.temperature),
        top_k=parse_field(generate_kwargs, "top_k", int, default_request.top_k),
        top_p=parse_field(generate_kwargs, "top_p", float, default_request.top_p),
        typical_p=parse_field(generate_kwargs, "typical_p", float, default_request.typical_p),
        repetition_penalty=parse_field(
            generate_kwargs, "repetition_penalty", float, default_request.repetition_penalty
        ),
        bos_token_id=parse_field(generate_kwargs, "bos_token_id", int, default_request.bos_token_id),
        pad_token_id=parse_field(generate_kwargs, "pad_token_id", int, default_request.pad_token_id),
        eos_token_id=parse_field(generate_kwargs, "eos_token_id", int, default_request.eos_token_id),
        length_penalty=parse_field(generate_kwargs, "length_penalty", float, default_request.length_penalty),
        no_repeat_ngram_size=parse_field(
            generate_kwargs, "no_repeat_ngram_size", int, default_request.no_repeat_ngram_size
        ),
        encoder_no_repeat_ngram_size=parse_field(
            generate_kwargs, "encoder_no_repeat_ngram_size", int, default_request.encoder_no_repeat_ngram_size
        ),
        max_time=parse_field(generate_kwargs, "max_time", float, default_request.max_time),
        max_new_tokens=parse_field(generate_kwargs, "max_new_tokens", int, default_request.max_new_tokens),
        decoder_start_token_id=parse_field(
            generate_kwargs, "decoder_start_token_id", int, default_request.decoder_start_token_id
        ),
        diversity_penalty=parse_field(generate_kwargs, "diversity_penalty", float, default_request.diversity_penalty),
        forced_bos_token_id=parse_field(
            generate_kwargs, "forced_bos_token_id", int, default_request.forced_bos_token_id
        ),
        forced_eos_token_id=parse_field(
            generate_kwargs, "forced_eos_token_id", int, default_request.forced_eos_token_id
        ),
        exponential_decay_length_penalty=parse_field(
            generate_kwargs,
            "exponential_decay_length_penalty",
            float,
            default_request.exponential_decay_length_penalty,
        ),
        remove_input_from_output=parse_field(
            generate_kwargs, "remove_input_from_output", bool, default_request.remove_input_from_output
        ),
    )
119138
120139
0 commit comments