diff --git a/xtuner/v1/module/attention/attn_outputs.py b/xtuner/v1/module/attention/attn_outputs.py
index e78cf2841..239981e0e 100644
--- a/xtuner/v1/module/attention/attn_outputs.py
+++ b/xtuner/v1/module/attention/attn_outputs.py
@@ -1,6 +1,5 @@
-from typing import TypedDict
-
 import torch
+from typing_extensions import TypedDict
 
 
 class AttnOutputs(TypedDict, total=False):
diff --git a/xtuner/v1/ops/attn_imp.py b/xtuner/v1/ops/attn_imp.py
index c37253f4f..3d2b1e018 100644
--- a/xtuner/v1/ops/attn_imp.py
+++ b/xtuner/v1/ops/attn_imp.py
@@ -1,6 +1,5 @@
 import traceback
 from functools import lru_cache
-from typing import TypedDict
 
 import torch
 import torch.nn as nn
@@ -14,6 +13,7 @@
 from torch.nn.attention.flex_attention import (
     flex_attention as torch_flex_attention,
 )
+from typing_extensions import TypedDict
 from transformers.models.llama.modeling_llama import repeat_kv
 
 
diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py
index b500b53e4..9cb373576 100644
--- a/xtuner/v1/rl/base/controller.py
+++ b/xtuner/v1/rl/base/controller.py
@@ -1,10 +1,11 @@
 import math
 import os
-from typing import Literal, TypedDict
+from typing import Literal
 
 import ray
 import torch
 from ray.actor import ActorProxy
+from typing_extensions import TypedDict
 
 from xtuner.v1.data_proto.sequence_context import SequenceContext
 from xtuner.v1.model.compose.base import BaseComposeConfig
diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py
index 86012ca40..7b43bd34a 100644
--- a/xtuner/v1/rl/base/worker.py
+++ b/xtuner/v1/rl/base/worker.py
@@ -4,7 +4,7 @@
 import time
 from itertools import chain
 from pathlib import Path
-from typing import Dict, Iterable, List, TypeAlias, TypedDict, cast
+from typing import Dict, Iterable, List, TypeAlias, cast
 
 import ray
 import requests
@@ -16,7 +16,7 @@
 from ray.actor import ActorClass, ActorProxy
 from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.distributed.tensor import DTensor
-from typing_extensions import NotRequired
+from typing_extensions import NotRequired, TypedDict
 from transformers import AutoTokenizer
 
 from xtuner.v1.config.fsdp import FSDPConfig
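Reviewer note: the diff consolidates TypedDict imports onto typing_extensions. A likely motivation (an assumption, not stated in the diff) is that typing.TypedDict on Python < 3.12 does not recognize the backported NotRequired marker, which worker.py already imports from typing_extensions, so optional keys can be mis-reported at runtime. A minimal sketch of the difference follows; RolloutItem is a hypothetical class used only for illustration, not one defined in xtuner.

```python
# Minimal sketch, assuming the motivation is consistent runtime handling of
# NotRequired across interpreter versions; `RolloutItem` is hypothetical.
from typing_extensions import NotRequired, TypedDict


class RolloutItem(TypedDict):
    prompt: str
    reward: NotRequired[float]  # optional key


# typing_extensions.TypedDict computes optional keys the same way on every
# supported Python version, so this prints frozenset({'reward'}).
# Mixing typing.TypedDict with the backported NotRequired can instead report
# no optional keys on Python 3.9/3.10.
print(RolloutItem.__optional_keys__)
```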