diff --git a/3.test_cases/pytorch/FSDP/src/train.py b/3.test_cases/pytorch/FSDP/src/train.py
index 867f7f085..82a97cb64 100644
--- a/3.test_cases/pytorch/FSDP/src/train.py
+++ b/3.test_cases/pytorch/FSDP/src/train.py
@@ -1,35 +1,24 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # SPDX-License-Identifier: MIT-0
 
-import datetime
 import functools
 import math
-import re
 import time
 
-import numpy as np
 import torch
 from torch import optim
 import torch.distributed as dist
 import torch.utils.data
-import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from datasets import load_dataset
-
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import MixedPrecision
 from torch.distributed.fsdp import ShardingStrategy
 from torch.distributed.fsdp import CPUOffload
-from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
-from torch.utils.data import DataLoader
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
-from model_utils.concat_dataset import ConcatTokensDataset
 from model_utils.train_utils import (get_model_config,
                                      compute_num_params,
                                      get_transformer_layer,
-                                     get_sharding_strategy,
-                                     get_backward_fetch_policy,
                                      apply_activation_checkpoint,
                                      get_param_groups_by_weight_decay,
                                      get_logger,
@@ -268,7 +257,7 @@ def main(args):
     val_dataloader = create_streaming_dataloader(args.dataset,
                                                  args.tokenizer,
                                                  name=args.dataset_config_name,
-                                                 batch_size=args.train_batch_size,
+                                                 batch_size=args.val_batch_size,
                                                  split='validation')
 
     train(model,