 import time
 
 # Third Party
-from datasets import load_dataset
+from datasets import Dataset, load_dataset
 from fms.models.hf import to_hf_api
+from fms.models.hf.modeling_hf_adapter import HFModelArchitecture
 from fms.utils import has_package
 from fms.utils.tokenizers import BaseTokenizer
 from torch import nn
 )
 
 
-def wrap_encoder(model):
+def wrap_encoder(model: nn.Module) -> HFModelArchitecture:
     """Add config info and wrapper to run pipeline for RoBERTa MaskedLM."""
 
+    if not has_hf:
+        raise ImportError(
+            "MaskedLM Encoder requires the transformers package but the "
+            "import was unsuccessful."
+        )
+
     model.config.linear_config.pop("linear_type", None)
     return to_hf_api(model, task_specific_params=None)
 
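+# Hypothetical usage sketch (illustrative names, not defined in this diff):
+# the wrapped model is HF-pipeline compatible, e.g.
+#
+#   hf_model = wrap_encoder(fms_model)
+#   unmasker = pipeline("fill-mask", model=hf_model, tokenizer=hf_tokenizer)
+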
@@ -47,9 +54,9 @@ def __init__(
         model: nn.Module,
         tokenizer: BaseTokenizer,
         args: argparse.Namespace,
-    ):
+    ) -> None:
         self.model = model
-        self.tokenizer = tokenizer
+        self.tokenizer = tokenizer.tokenizer  # extract original HF tokenizer
         self.args = args
 
         self.question_column_name = ""
@@ -59,7 +66,7 @@ def __init__(
 
         self.validate_encoder_arguments()
 
-    def validate_encoder_arguments(self):
+    def validate_encoder_arguments(self) -> None:
         """Ensure argument compatibility with Encoder models."""
 
         args = self.args
@@ -85,10 +92,14 @@ def validate_encoder_arguments(self):
             )
 
 
-    def prepare_validation_features(self, examples):
+    def prepare_validation_features(
+        self,
+        examples: dict[str, list[str | dict]],
+    ) -> dict[str, list]:
         """Validation preprocessing."""
 
         args = self.args
+
         q_col_name = self.question_column_name
         c_col_name = self.context_column_name
         pad_on_right = self.pad_on_right
@@ -109,7 +120,7 @@ def prepare_validation_features(self, examples):
         # using a stride. This results in one example possibly giving several features
         # when a context is long, each of those features having a context that overlaps
         # a bit the context of the previous feature.
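+        # Illustrative arithmetic (assumed values, not from this diff): with
+        # max_seq_length=384 and doc_stride=128, each window advances by
+        # 384 - 128 = 256 tokens, so a ~1000-token context yields about four
+        # overlapping features.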
-        tokenized_examples = self.tokenizer.tokenize(
+        tokenized_examples = self.tokenizer(
             examples[q_col_name if pad_on_right else c_col_name],
             examples[c_col_name if pad_on_right else q_col_name],
             truncation="only_second" if pad_on_right else "only_first",
@@ -149,12 +160,15 @@ def prepare_validation_features(self, examples):
 
         return tokenized_examples
 
-    def convert_batch_to_fms_style(self, batch):
+    def convert_batch_to_fms_style(
+        self,
+        batch: dict[str, torch.Tensor],
+    ) -> dict[str, torch.Tensor]:
         """FMS uses a different standard than HF for encoder inputs."""
 
         return {'x': batch['input_ids'], 'mask': batch['attention_mask']}
 
-    def process_eval_set(self):
+    def process_eval_set(self) -> None:
         """Pre-process evaluation dataset for QuestionAnswering task."""
 
         if not has_hf:
@@ -192,7 +206,7 @@ def process_eval_set(self):
         # Padding side determines if we do (question|context) or (context|question)
         self.pad_on_right = self.tokenizer.padding_side == "right"
 
-        model_max_length = self.tokenizer.tokenizer.model_max_length  # TODO: add model_max_length to FMS _HFTokenizer
+        model_max_length = self.tokenizer.model_max_length
         if args.max_prompt_length > model_max_length:
             dprint(
                 f"max_prompt_length ({args.max_prompt_length}) is larger than the "
@@ -259,16 +273,16 @@ def process_eval_set(self):
 
     def postprocess_qa_predictions(
         self,
-        examples,
-        features,
+        examples: Dataset,
+        features: Dataset,
         predictions: tuple[np.ndarray, np.ndarray],
         version_2_with_negative: bool = False,
         n_best_size: int = 20,
         max_answer_length: int = 30,
         null_score_diff_threshold: float = 0.0,
         output_dir: str | None = None,
         prefix: str | None = None,
-    ):
+    ) -> dict[str, str]:
         """
         Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
         original contexts. This is the base postprocessing function for models that only return start and end logits.
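+
+        Illustrative example (assumed variable names, not from this diff): if the
+        best start/end logits for a feature peak at token indices (s, e), the
+        predicted answer is context[offset_mapping[s][0] : offset_mapping[e][1]].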
@@ -476,7 +490,13 @@ def postprocess_qa_predictions(
 
         return all_predictions
 
-    def post_processing_function(self, examples, features, predictions, stage="eval"):
+    def post_processing_function(
+        self,
+        examples: Dataset,
+        features: Dataset,
+        predictions: list[np.ndarray],
+        stage: str = "eval",
+    ) -> EvalPrediction:
         """Post-processing: we match the start logits and end logits to answers in
         the original context."""
482502
@@ -492,6 +512,7 @@ def post_processing_function(self, examples, features, predictions, stage="eval"
             output_dir=None,
             prefix=stage,
         )
         # Format the result to the format the metric expects.
         if args.version_2_with_negative:
             formatted_predictions = [
@@ -508,7 +529,12 @@ def post_processing_function(self, examples, features, predictions, stage="eval"
             ]
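+        # For reference, the squad_v2 metric consumes prediction entries shaped
+        # like {"id": ..., "prediction_text": ..., "no_answer_probability": 0.0};
+        # this comment is an assumption based on the HF metric, not this diff.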
         return EvalPrediction(predictions=formatted_predictions, label_ids=references)
 
-    def create_and_fill_np_array(self, start_or_end_logits, dataset, max_len):
+    def create_and_fill_np_array(
+        self,
+        start_or_end_logits: list[np.ndarray],
+        dataset: Dataset,
+        max_len: int,
+    ) -> np.ndarray:
         """
         Create and fill numpy array of size
         len_of_validation_data * max_length_of_output_tensor
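+
+        Illustrative example (assumed shapes): per-batch start logits of shapes
+        (8, 350), (8, 372), (8, 384) are written into one padded (24, 384)
+        array so rows can be indexed uniformly by downstream metrics.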
@@ -543,7 +569,7 @@ def create_and_fill_np_array(self, start_or_end_logits, dataset, max_len):
 
         return logits_concat
 
-    def run_warmup(self):
+    def run_warmup(self) -> None:
         """Run warmup cycle of compiled encoder model set for QuestionAnswering task."""
 
         dprint("Starting warm-up...")
@@ -559,7 +585,7 @@ def run_warmup(self):
         if rank == 0:
             dprint(f"Warmup completed in {time.time() - warmup_start_time:.1f}s\n---")
 
-    def run_evaluation(self):
+    def run_evaluation(self) -> None:
         """Run QuestionAnswering evaluation."""
 
         args = self.args
@@ -587,7 +613,7 @@ def run_evaluation(self):
587613 f"(tot = { len (eval_dataloader ) * args .batch_size } , "
588614 f"bs = { args .batch_size } )"
589615 )
590-
616+ breakpoint ()
591617 # concatenate the numpy array
592618 max_len = max ([x .shape [1 ] for x in all_start_logits ])
593619 start_logits_concat = self .create_and_fill_np_array (
@@ -622,21 +648,27 @@ class EncoderMLMInfer():
 
     def __init__(
         self,
-        model: nn.Module,
+        model: HFModelArchitecture,
         tokenizer: BaseTokenizer,
         args: argparse.Namespace,
-    ):
+    ) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.args = args
 
 
-    def process_eval_set(self):
+    def process_eval_set(self) -> None:
         """Barebones function that sets up a single example prompt (for now)."""
 
+        if not has_hf:
+            raise ImportError(
+                "MaskedLM Encoder requires the transformers package but the "
+                "import was unsuccessful."
+            )
+
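+        # "<mask>" is the RoBERTa-style mask token; the MLM head is expected to
+        # fill it with a plausible token (illustrative note, not asserted).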
         self.prompt = "the dog chased the cat while<mask> aggressively"
 
-    def run_evaluation(self, warmup=False):
+    def run_evaluation(self, warmup: bool = False) -> None:
         """Run evaluation cycle of compiled encoder model set for MaskedLM task.
         No output is printed if warmup is True.
         """
@@ -658,10 +690,10 @@ def run_evaluation(self, warmup=False):
 
 
 def run_encoder_eval_qa(
-    model: nn.Module,
+    model: nn.Module,  # FMS-style model
     tokenizer: BaseTokenizer,
     args: argparse.Namespace,
-):
+) -> None:
     """Entry point to run QuestionAnswering evaluation of an encoder model.
 
     Processing based on PyTorch example:
@@ -677,10 +709,10 @@ def run_encoder_eval_qa(
 
 
 def run_encoder_eval_mlm(
-    model: nn.Module,
+    model: HFModelArchitecture,  # model wrapped by to_hf_api
     tokenizer: BaseTokenizer,
     args: argparse.Namespace,
-):
+) -> None:
     """Entry point to run evaluation of encoder models."""
 
     encoder_mlm_infer = EncoderMLMInfer(model, tokenizer, args)