feat: Refactor training framework with modular architecture #3
base: master
Conversation
dp = parallel.get("dp", 1)
tp = parallel.get("tp", 1)
pp = parallel.get("pp", {}).get("value", 1)
What should we do if the pp type is not G-pipe?
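A minimal sketch of one way to handle this: fail loudly on unsupported pipeline schedules instead of silently assuming G-pipe. The `type` key and the supported-schedule set are assumptions, not part of this PR:

```python
# Hypothetical: key names and supported schedules are assumed, adjust to the framework
SUPPORTED_PP_SCHEDULES = {"gpipe", "1f1b"}

pp_cfg = parallel.get("pp", {})
pp_schedule = str(pp_cfg.get("type", "gpipe")).lower()
if pp_schedule not in SUPPORTED_PP_SCHEDULES:
    raise ValueError(f"unsupported pp schedule: {pp_schedule!r}")
pp = pp_cfg.get("value", 1)
```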
megatron_args = [
    "pretrain_gpt.py",
Is this script only for GPT-style models?
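If other model families should be supported, a sketch like the following could select the Megatron-LM entry point per family; `self.config.model_family` is a hypothetical field, while the script names are the ones that ship with Megatron-LM:

```python
# Hypothetical mapping; self.config.model_family is an assumed config field
PRETRAIN_SCRIPTS = {
    "gpt": "pretrain_gpt.py",
    "bert": "pretrain_bert.py",
    "t5": "pretrain_t5.py",
}

script = PRETRAIN_SCRIPTS.get(self.config.model_family, "pretrain_gpt.py")
megatron_args = [script]
```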
torchrun_cmd = [
    "torchrun",
    f"--nproc_per_node={nproc_per_node}",
    "--master_port=29501"
It would be best not to hard-code this port.
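A minimal sketch of one fix: honor a `MASTER_PORT` environment override (a common torchrun convention) and otherwise ask the OS for a free port. `find_free_port` is a hypothetical helper, and `nproc_per_node` comes from the surrounding code:

```python
import os
import socket

def find_free_port() -> int:
    """Ask the OS for an unused TCP port."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]

master_port = int(os.environ.get("MASTER_PORT") or find_free_port())
torchrun_cmd = [
    "torchrun",
    f"--nproc_per_node={nproc_per_node}",
    f"--master_port={master_port}",
]
```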
| f"--num-attention-heads={num_attention_heads}", | ||
| f"--max-position-embeddings={max_position_embeddings}", | ||
| f"--vocab-size={vocab_size}", | ||
| "--mock-data", |
If our config specifies a dataset, won't passing mock data here break things?
# regex patterns
loss_pattern = re.compile(r"lm loss:\s*([+\-]?\d+(?:\.\d+)?(?:[Ee][+\-]?\d+)?)", re.IGNORECASE)
#ppl_pattern_alt = re.compile(r"lm loss PPL:\s*([+\-]?\d+(?:\.\d+)?(?:[Ee][+\-]?\d+)?)", re.IGNORECASE)
Is this still needed? If not, it can be removed.
if me:
    try:
        elapsed_ms = float(me.group(1))
        tokens_per_iter = mbs * seq
Shouldn't this also be divided by the number of GPUs?
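One possible reading of the fix, as a sketch only: compute global tokens per iteration from the global batch size and then normalize by the GPU count. `gbs` and `world_size` are assumed to be available from the surrounding code or config:

```python
# Sketch: per-GPU throughput; gbs = global batch size, world_size = total GPUs (assumed)
if me:
    try:
        elapsed_ms = float(me.group(1))
        tokens_per_iter = gbs * seq                       # tokens across all ranks
        tokens_per_sec = tokens_per_iter / (elapsed_ms / 1000.0)
        tokens_per_sec_per_gpu = tokens_per_sec / world_size
    except (ValueError, ZeroDivisionError):
        pass
```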
flog.write(line)

# try match loss
m = loss_pattern.search(line)
Does this skip the warmup training iterations?
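If not, a sketch like this could track the iteration number from the log stream and ignore loss lines from warmup. `loss_pattern` is the one defined in the PR; the iteration regex, `WARMUP_ITERS`, and `log_lines` are assumptions for illustration:

```python
import re

loss_pattern = re.compile(r"lm loss:\s*([+\-]?\d+(?:\.\d+)?(?:[Ee][+\-]?\d+)?)", re.IGNORECASE)
iter_pattern = re.compile(r"iteration\s+(\d+)\s*/", re.IGNORECASE)  # assumed log format
WARMUP_ITERS = 10  # assumed; could come from config

losses = []
current_iter = 0  # updated as log lines stream in
for line in log_lines:  # log_lines stands in for the streamed training output
    mi = iter_pattern.search(line)
    if mi:
        current_iter = int(mi.group(1))
    m = loss_pattern.search(line)
    if m and current_iter > WARMUP_ITERS:
        losses.append(float(m.group(1)))
```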
# Start training process
print("Launching:", " ".join(cmd))
We should not use print for logging; use the logging module here. Please apply this throughout the PR.
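A minimal sketch of the suggested change, using the standard `logging` module; the format string is just an example:

```python
import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logger = logging.getLogger(__name__)

# Start training process
logger.info("Launching: %s", " ".join(cmd))
```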
# Add data configuration
if self.config.train_dataset is None or (isinstance(self.config.train_dataset, str) and self.config.train_dataset.lower() == "mock"):
    megatron_args += ["--mock-data", "--tokenizer-type", "NullTokenizer", "--vocab-size", str(self.config.vocab_size)]
Do we need a different tokenizer for different models here?
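If so, a sketch along these lines could branch on a tokenizer field in the config. The Megatron-LM tokenizer types and flags (`GPT2BPETokenizer`, `SentencePieceTokenizer`, `--vocab-file`, `--merge-file`, `--tokenizer-model`) are real; the `self.config.tokenizer_*` fields are assumptions:

```python
# Hypothetical config fields; Megatron-LM tokenizer flags are the real ones
if self.config.tokenizer_type == "sentencepiece":
    megatron_args += ["--tokenizer-type", "SentencePieceTokenizer",
                      "--tokenizer-model", self.config.tokenizer_model]
elif self.config.tokenizer_type == "gpt2":
    megatron_args += ["--tokenizer-type", "GPT2BPETokenizer",
                      "--vocab-file", self.config.vocab_file,
                      "--merge-file", self.config.merge_file]
else:  # mock-data fallback, as in the PR
    megatron_args += ["--tokenizer-type", "NullTokenizer",
                      "--vocab-size", str(self.config.vocab_size)]
```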
| f"--micro-batch-size={self.config.mbs}", | ||
| f"--global-batch-size={self.config.gbs}", | ||
| f"--seq-length={self.config.seq_len}", | ||
| f"--lr={self.config.lr}", |
Please also consider LR decay here. We could add these configs to our JSON file:
f"--lr_scheduler_type=cosine",  # recommended: cosine annealing (cosine) or linear
f"--warmup_ratio=0.03",         # recommended: warm up over the first 3% of steps
# or use a fixed number of steps instead:
# f"--warmup_steps=100",
# Add common parameters
megatron_args += [
    "--transformer-impl", "local",
    "--bf16",
We should get the precision from our JSON file.
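A sketch of what that could look like; the `precision` config field is an assumption, while `--bf16` and `--fp16` are Megatron-LM's actual flags:

```python
precision = getattr(self.config, "precision", "bf16")  # 'precision' field assumed
if precision == "bf16":
    megatron_args += ["--bf16"]
elif precision == "fp16":
    megatron_args += ["--fp16"]
# fp32: Megatron-LM trains in fp32 when neither flag is given
```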
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config.json")
    parser.add_argument("--framework", default="megatron", choices=["megatron", "infinitrain"],
Could we get this setting from the JSON file?
| parser.add_argument("--config", required=True, help="path to config.json") | ||
| parser.add_argument("--framework", default="megatron", choices=["megatron", "infinitrain"], | ||
| help="training framework to use") | ||
| parser.add_argument("--gpu-platform", default="nvidia", choices=["nvidia", "other"], |
Please add a device argument to the config file.
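One way to fold both this and the framework setting into the JSON config, keeping the CLI flags as optional overrides; the `"framework"` and `"device"` field names are illustrative, not defined by this PR:

```python
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--config", required=True, help="path to config.json")
# CLI flags become optional overrides once the fields live in config.json
parser.add_argument("--framework", default=None, choices=["megatron", "infinitrain"])
parser.add_argument("--gpu-platform", default=None, choices=["nvidia", "other"])
args = parser.parse_args()

with open(args.config) as f:
    cfg = json.load(f)

framework = args.framework or cfg.get("framework", "megatron")  # assumed field
device = args.gpu_platform or cfg.get("device", "nvidia")        # assumed field
```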
baominghelly left a comment
Comments added in the PR.
zzhfz left a comment
@baominghelly The training-script changes are complete; please review the latest version.
Key updates:
- Standardized output format: run_id/testcase conform to the spec
- Full configuration support
- All 6 review comments resolved
- Passed the Megatron-LM tests
Description
Add Modular Architecture Refactoring script
Evidence
./train/train.gpt.946e7e31-adaa-4421-bef1-63eb7726402f_result.json
./train/train.gpt.946e7e31-adaa-4421-bef1-63eb7726402f_train_loss.csv
./train/train.gpt.946e7e31-adaa-4421-bef1-63eb7726402f_train_ppl.csv
./train/train.gpt.946e7e31-adaa-4421-bef1-63eb7726402f_train_throughput.csv