diff --git a/.github/workflows/greeting-ainode.yml b/.github/workflows/greeting-ainode.yml new file mode 100644 index 0000000000000..152293740f6db --- /dev/null +++ b/.github/workflows/greeting-ainode.yml @@ -0,0 +1,52 @@ +name: AINode Code Style Check + +on: + push: + branches: + - master + - "rc/*" + paths: + - 'iotdb-core/ainode/**' + pull_request: + branches: + - master + - "rc/*" + paths: + - 'iotdb-core/ainode/**' + # allow running the workflow manually: + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +env: + MAVEN_OPTS: -Dhttp.keepAlive=false -Dmaven.wagon.http.pool=false -Dmaven.wagon.http.retryHandler.class=standard -Dmaven.wagon.http.retryHandler.count=3 + MAVEN_ARGS: --batch-mode --no-transfer-progress + +jobs: + check-style: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python 3.10 + uses: actions/setup-python@v5 + with: + python-version: "3.10" + + - name: Install dependencies + run: | + pip3 install black==25.1.0 isort==6.0.1 + - name: Check code formatting (Black) + run: | + cd iotdb-core/ainode + black --check . + continue-on-error: false + + - name: Check import order (Isort) + run: | + cd iotdb-core/ainode + isort --check-only --profile black . + continue-on-error: false \ No newline at end of file diff --git a/iotdb-core/ainode/ainode/TimerXL/__init__.py b/iotdb-core/ainode/ainode/TimerXL/__init__.py index 4b8ee97fad2be..2a1e720805f29 100644 --- a/iotdb-core/ainode/ainode/TimerXL/__init__.py +++ b/iotdb-core/ainode/ainode/TimerXL/__init__.py @@ -14,4 +14,4 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# \ No newline at end of file +# diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Bias.py b/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Bias.py index d374039579ac0..3c5ad7600328d 100644 --- a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Bias.py +++ b/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Bias.py @@ -17,6 +17,7 @@ # import abc import math + import torch from einops import rearrange from torch import nn @@ -41,22 +42,23 @@ def __init__(self, dim: int, num_heads: int): def forward(self, query_id, kv_id): ind = torch.eq(query_id.unsqueeze(-1), kv_id.unsqueeze(-2)) - weight = rearrange( - self.emb.weight, "two num_heads -> two num_heads 1 1") + weight = rearrange(self.emb.weight, "two num_heads -> two num_heads 1 1") bias = ~ind * weight[:1] + ind * weight[1:] return bias -def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128): +def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=128 +): relative_buckets = 0 if bidirectional: num_buckets //= 2 - relative_buckets += (relative_position > - 0).to(torch.long) * num_buckets + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets relative_position = torch.abs(relative_position) else: - relative_position = - \ - torch.min(relative_position, torch.zeros_like(relative_position)) + relative_position = -torch.min( + relative_position, torch.zeros_like(relative_position) + ) max_exact = num_buckets // 2 is_small = relative_position < max_exact @@ -66,12 +68,13 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets * (num_buckets - max_exact) ).to(torch.long) relative_position_if_large = torch.min( - relative_position_if_large, torch.full_like( -
relative_position_if_large, num_buckets - 1) + relative_position_if_large, + torch.full_like(relative_position_if_large, num_buckets - 1), ) - relative_buckets += torch.where(is_small, - relative_position, relative_position_if_large) + relative_buckets += torch.where( + is_small, relative_position, relative_position_if_large + ) return relative_buckets @@ -83,11 +86,21 @@ def __init__(self, dim: int, num_heads: int): self.relative_attention_bias = nn.Embedding(self.num_buckets, 1) def forward(self, n_vars, n_tokens): - context_position = torch.arange(n_tokens, dtype=torch.long,)[:, None] - memory_position = torch.arange(n_tokens, dtype=torch.long, )[None, :] + context_position = torch.arange( + n_tokens, + dtype=torch.long, + )[:, None] + memory_position = torch.arange( + n_tokens, + dtype=torch.long, + )[None, :] relative_position = memory_position - context_position - bucket = _relative_position_bucket(relative_position=relative_position, bidirectional=False, - num_buckets=self.num_buckets, max_distance=self.max_distance).to(self.relative_attention_bias.weight.device) + bucket = _relative_position_bucket( + relative_position=relative_position, + bidirectional=False, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ).to(self.relative_attention_bias.weight.device) bias = self.relative_attention_bias(bucket).squeeze(-1) bias = bias.reshape(1, 1, bias.shape[0], bias.shape[1]) mask1 = torch.ones((n_vars, n_vars), dtype=torch.bool).to(bias.device) diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Projection.py b/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Projection.py index f53c849b441a1..18e2b29c3d6e6 100644 --- a/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Projection.py +++ b/iotdb-core/ainode/ainode/TimerXL/layers/Attn_Projection.py @@ -16,8 +16,9 @@ # under the License. # import abc -import torch from functools import cached_property + +import torch from einops import einsum, rearrange, repeat from torch import nn @@ -33,7 +34,9 @@ def forward(self, x, seq_id): ... class RotaryProjection(Projection): - def __init__(self, *, proj_width: int, num_heads: int, max_len: int = 512, base: int = 10000): + def __init__( + self, *, proj_width: int, num_heads: int, max_len: int = 512, base: int = 10000 + ): super().__init__(proj_width, num_heads) assert ( self.proj_width % 2 == 0 @@ -57,8 +60,7 @@ def _init_freq(self, max_len: int): position = torch.arange( max_len, device=self.theta.device, dtype=self.theta.dtype ) - m_theta = einsum(position, self.theta, - "length, width -> length width") + m_theta = einsum(position, self.theta, "length, width -> length width") m_theta = repeat(m_theta, "length width -> length (width 2)") self.register_buffer("cos", torch.cos(m_theta), persistent=False) self.register_buffer("sin", torch.sin(m_theta), persistent=False) @@ -76,7 +78,9 @@ def forward(self, x, seq_id): class QueryKeyProjection(nn.Module): - def __init__(self, dim: int, num_heads: int, proj_layer, kwargs=None, partial_factor=None): + def __init__( + self, dim: int, num_heads: int, proj_layer, kwargs=None, partial_factor=None + ): super().__init__() if partial_factor is not None: assert ( diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Embed.py b/iotdb-core/ainode/ainode/TimerXL/layers/Embed.py index 33747b2fe60ba..8c3cf570bafe5 100644 --- a/iotdb-core/ainode/ainode/TimerXL/layers/Embed.py +++ b/iotdb-core/ainode/ainode/TimerXL/layers/Embed.py @@ -16,11 +16,14 @@ # under the License. 
# import math + import torch import torch.nn as nn from torch.jit import is_scripting + from ainode.TimerXL.models.configuration_timer import TimerxlConfig + class PositionalEmbedding(nn.Module): def __init__(self, d_model, max_len=6500): super(PositionalEmbedding, self).__init__() @@ -29,29 +32,37 @@ def __init__(self, d_model, max_len=6500): pe.require_grad = False position = torch.arange(0, max_len).float().unsqueeze(1) - div_term = (torch.arange(0, d_model, 2).float() - * -(math.log(10000.0) / d_model)).exp() + div_term = ( + torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) + ).exp() pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) - self.register_buffer('pe', pe) + self.register_buffer("pe", pe) def forward(self, x): - return self.pe[:, :x.size(1)] + return self.pe[:, : x.size(1)] class TokenEmbedding(nn.Module): def __init__(self, c_in, d_model): super(TokenEmbedding, self).__init__() - padding = 1 if torch.__version__ >= '1.5.0' else 2 - self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, - kernel_size=3, padding=padding, padding_mode='circular', bias=False) + padding = 1 if torch.__version__ >= "1.5.0" else 2 + self.tokenConv = nn.Conv1d( + in_channels=c_in, + out_channels=d_model, + kernel_size=3, + padding=padding, + padding_mode="circular", + bias=False, + ) for m in self.modules(): if isinstance(m, nn.Conv1d): nn.init.kaiming_normal_( - m.weight, mode='fan_in', nonlinearity='leaky_relu') + m.weight, mode="fan_in", nonlinearity="leaky_relu" + ) def forward(self, x): x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2) @@ -66,8 +77,9 @@ def __init__(self, c_in, d_model): w.require_grad = False position = torch.arange(0, c_in).float().unsqueeze(1) - div_term = (torch.arange(0, d_model, 2).float() - * -(math.log(10000.0) / d_model)).exp() + div_term = ( + torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model) + ).exp() w[:, 0::2] = torch.sin(position * div_term) w[:, 1::2] = torch.cos(position * div_term) @@ -80,7 +92,7 @@ def forward(self, x): class TemporalEmbedding(nn.Module): - def __init__(self, d_model, embed_type='fixed', freq='h'): + def __init__(self, d_model, embed_type="fixed", freq="h"): super(TemporalEmbedding, self).__init__() minute_size = 4 @@ -89,8 +101,8 @@ def __init__(self, d_model, embed_type='fixed', freq='h'): day_size = 32 month_size = 13 - Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding - if freq == 't': + Embed = FixedEmbedding if embed_type == "fixed" else nn.Embedding + if freq == "t": self.minute_embed = Embed(minute_size, d_model) self.hour_embed = Embed(hour_size, d_model) self.weekday_embed = Embed(weekday_size, d_model) @@ -99,8 +111,9 @@ def __init__(self, d_model, embed_type='fixed', freq='h'): def forward(self, x): x = x.long() - minute_x = self.minute_embed(x[:, :, 4]) if hasattr( - self, 'minute_embed') else 0. 
+ minute_x = ( + self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0 + ) hour_x = self.hour_embed(x[:, :, 3]) weekday_x = self.weekday_embed(x[:, :, 2]) day_x = self.day_embed(x[:, :, 1]) @@ -110,11 +123,10 @@ def forward(self, x): class TimeFeatureEmbedding(nn.Module): - def __init__(self, d_model, embed_type='timeF', freq='h'): + def __init__(self, d_model, embed_type="timeF", freq="h"): super(TimeFeatureEmbedding, self).__init__() - freq_map = {'h': 4, 't': 5, 's': 6, - 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3} + freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3} d_inp = freq_map[freq] self.embed = nn.Linear(d_inp, d_model, bias=False) @@ -123,27 +135,32 @@ def forward(self, x): class DataEmbedding(nn.Module): - def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): + def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): super(DataEmbedding, self).__init__() self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) self.position_embedding = PositionalEmbedding(d_model=d_model) - self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, - freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( - d_model=d_model, embed_type=embed_type, freq=freq) + self.temporal_embedding = ( + TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) + if embed_type != "timeF" + else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) + ) self.dropout = nn.Dropout(p=dropout) def forward(self, x, x_mark): if x_mark is None: x = self.value_embedding(x) + self.position_embedding(x) else: - x = self.value_embedding( - x) + self.temporal_embedding(x_mark) + self.position_embedding(x) + x = ( + self.value_embedding(x) + + self.temporal_embedding(x_mark) + + self.position_embedding(x) + ) return self.dropout(x) class DataEmbedding_inverted(nn.Module): - def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): + def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): super(DataEmbedding_inverted, self).__init__() self.value_embedding = nn.Linear(c_in, d_model) self.dropout = nn.Dropout(p=dropout) @@ -154,21 +171,22 @@ def forward(self, x, x_mark): if x_mark is None: x = self.value_embedding(x) else: - x = self.value_embedding( - torch.cat([x, x_mark.permute(0, 2, 1)], 1)) + x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1)) # x: [Batch Variate d_model] return self.dropout(x) class DataEmbedding_wo_pos(nn.Module): - def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1): + def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1): super(DataEmbedding_wo_pos, self).__init__() self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model) self.position_embedding = PositionalEmbedding(d_model=d_model) - self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, - freq=freq) if embed_type != 'timeF' else TimeFeatureEmbedding( - d_model=d_model, embed_type=embed_type, freq=freq) + self.temporal_embedding = ( + TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) + if embed_type != "timeF" + else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) + ) self.dropout = nn.Dropout(p=dropout) def forward(self, x, x_mark): @@ -205,19 +223,21 @@ def forward(self, x): # Input encoding x = self.value_embedding(x) + self.position_embedding(x) return self.dropout(x), n_vars - + + class 
TimerPatchEmbedding(nn.Module): def __init__(self, config: TimerxlConfig): super().__init__() self.input_token_len = config.input_token_len - self.emb = nn.Linear(config.input_token_len, - config.hidden_size, bias=False) + self.emb = nn.Linear(config.input_token_len, config.hidden_size, bias=False) def forward(self, hidden_state: torch.Tensor): hidden_state = hidden_state.unfold( - dimension=-1, size=self.input_token_len, step=self.input_token_len) + dimension=-1, size=self.input_token_len, step=self.input_token_len + ) return self.emb(hidden_state) - + + class TimeMoeRotaryEmbedding(torch.nn.Module): def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None): super().__init__() @@ -225,19 +245,29 @@ def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None): self.max_position_embeddings = max_position_embeddings self.base = base self.max_seq_len_cached: int = 0 - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, - 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq = 1.0 / ( + self.base + ** ( + torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) + / self.dim + ) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), ) - def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype): + def _set_cos_sin_cache( + self, seq_len: int, device: torch.device, dtype: torch.dtype + ): self.max_seq_len_cached = int(seq_len) - t = torch.arange(self.max_seq_len_cached, device=device, - dtype=torch.int64).type_as(self.inv_freq) + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=torch.int64 + ).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation @@ -249,11 +279,10 @@ def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dt self.cos_cached = emb.cos().to(dtype) self.sin_cached = emb.sin().to(dtype) - def forward(self, x, seq_len: int=0): + def forward(self, x, seq_len: int = 0): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache( - seq_len=seq_len, device=x.device, dtype=x.dtype) + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( self.cos_cached[:seq_len].to(dtype=x.dtype), diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/SelfAttention_Family.py b/iotdb-core/ainode/ainode/TimerXL/layers/SelfAttention_Family.py index 82ac54fffce3e..4a2fb0d27e09f 100644 --- a/iotdb-core/ainode/ainode/TimerXL/layers/SelfAttention_Family.py +++ b/iotdb-core/ainode/ainode/TimerXL/layers/SelfAttention_Family.py @@ -15,25 +15,33 @@ # specific language governing permissions and limitations # under the License. 
# +from math import sqrt +from typing import Any, Optional, Tuple + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from typing import Optional, Tuple, Any -import numpy as np -from math import sqrt from einops import repeat + +from ainode.core.util.huggingface_cache import Cache, DynamicCache +from ainode.core.util.masking import ( + TimerCovariateMask, + TimerMultivariateMask, + TriangularCausalMask, +) from ainode.TimerXL.layers.Attn_Bias import BinaryAttentionBias from ainode.TimerXL.layers.Attn_Projection import QueryKeyProjection, RotaryProjection from ainode.TimerXL.layers.Embed import TimeMoeRotaryEmbedding -from ainode.core.util.masking import TriangularCausalMask, TimerMultivariateMask, TimerCovariateMask from ainode.TimerXL.models.configuration_timer import TimerxlConfig -from ainode.core.util.huggingface_cache import Cache, DynamicCache + def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) + def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): cos = cos[position_ids].unsqueeze(unsqueeze_dim) sin = sin[position_ids].unsqueeze(unsqueeze_dim) @@ -41,18 +49,31 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed + class FullAttention(nn.Module): - def __init__(self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False): + def __init__( + self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False + ): super(FullAttention, self).__init__() self.scale = scale self.mask_flag = mask_flag self.output_attention = output_attention self.dropout = nn.Dropout(attention_dropout) - def forward(self, queries, keys, values, attn_mask, n_vars=None, n_tokens=None, tau=None, delta=None): + def forward( + self, + queries, + keys, + values, + attn_mask, + n_vars=None, + n_tokens=None, + tau=None, + delta=None, + ): B, L, H, E = queries.shape _, S, _, D = values.shape - scale = self.scale or 1. 
/ sqrt(E) + scale = self.scale or 1.0 / sqrt(E) scores = torch.einsum("blhe,bshe->bhls", queries, keys) @@ -84,14 +105,15 @@ def __init__(self, config: TimerxlConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) self.rotary_emb = TimeMoeRotaryEmbedding( - self.head_dim, max_position_embeddings=config.max_position_embeddings) + self.head_dim, max_position_embeddings=config.max_position_embeddings + ) def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional["Cache"] = None, + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional["Cache"] = None, ) -> Tuple[torch.Tensor, Optional["Cache"]]: bsz, q_len, _ = hidden_states.size() @@ -100,26 +122,35 @@ def forward( value_states = self.v_proj(hidden_states) query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) key_states = key_states.view( - bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) value_states = value_states.view( - bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length( - kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids) - + query_states, key_states, cos, sin, position_ids + ) + if past_key_value is not None: key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx) - + key_states, value_states, self.layer_idx + ) + attn_output = F.scaled_dot_product_attention( - query_states, key_states, value_states, attention_mask, dropout_p=self.attention_dropout) + query_states, + key_states, + value_states, + attention_mask, + dropout_p=self.attention_dropout, + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -142,7 +173,17 @@ def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None): self.out_projection = nn.Linear(d_values * n_heads, d_model) self.n_heads = n_heads - def forward(self, queries, keys, values, attn_mask, n_vars=None, n_tokens=None, tau=None, delta=None): + def forward( + self, + queries, + keys, + values, + attn_mask, + n_vars=None, + n_tokens=None, + tau=None, + delta=None, + ): B, L, _ = queries.shape _, S, _ = keys.shape H = self.n_heads @@ -159,9 +200,8 @@ def forward(self, queries, keys, values, attn_mask, n_vars=None, n_tokens=None, n_vars=n_vars, n_tokens=n_tokens, tau=tau, - delta=delta + delta=delta, ) out = out.view(B, L, -1) return self.out_projection(out), attn - diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/Transformer_EncDec.py b/iotdb-core/ainode/ainode/TimerXL/layers/Transformer_EncDec.py index ceb3d7ebe1653..d5bad30ea0559 100644 --- 
a/iotdb-core/ainode/ainode/TimerXL/layers/Transformer_EncDec.py +++ b/iotdb-core/ainode/ainode/TimerXL/layers/Transformer_EncDec.py @@ -15,35 +15,32 @@ # specific language governing permissions and limitations # under the License. # +from typing import Optional, Tuple + import torch import torch.nn as nn import torch.nn.functional as F -from typing import Optional, Tuple -from ainode.TimerXL.models.configuration_timer import TimerxlConfig -from ainode.TimerXL.layers.SelfAttention_Family import TimerAttention + from ainode.core.util.activation import ACT2FN from ainode.core.util.huggingface_cache import Cache, DynamicCache +from ainode.TimerXL.layers.SelfAttention_Family import TimerAttention +from ainode.TimerXL.models.configuration_timer import TimerxlConfig + class EncoderLayer(nn.Module): def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"): super(EncoderLayer, self).__init__() d_ff = d_ff or 4 * d_model self.attention = attention - self.conv1 = nn.Conv1d(in_channels=d_model, - out_channels=d_ff, kernel_size=1) - self.conv2 = nn.Conv1d( - in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) self.activation = F.relu if activation == "relu" else F.gelu def forward(self, x, attn_mask=None, tau=None, delta=None): - new_x, attn = self.attention( - x, x, x, - attn_mask=attn_mask, - tau=tau, delta=delta - ) + new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta) x = x + self.dropout(new_x) y = x = self.norm1(x) @@ -54,16 +51,21 @@ def forward(self, x, attn_mask=None, tau=None, delta=None): class DecoderLayer(nn.Module): - def __init__(self, self_attention, cross_attention, d_model, d_ff=None, - dropout=0.1, activation="relu"): + def __init__( + self, + self_attention, + cross_attention, + d_model, + d_ff=None, + dropout=0.1, + activation="relu", + ): super(DecoderLayer, self).__init__() d_ff = d_ff or 4 * d_model self.self_attention = self_attention self.cross_attention = cross_attention - self.conv1 = nn.Conv1d(in_channels=d_model, - out_channels=d_ff, kernel_size=1) - self.conv2 = nn.Conv1d( - in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.norm3 = nn.LayerNorm(d_model) @@ -71,18 +73,16 @@ def __init__(self, self_attention, cross_attention, d_model, d_ff=None, self.activation = F.relu if activation == "relu" else F.gelu def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): - x = x + self.dropout(self.self_attention( - x, x, x, - attn_mask=x_mask, - tau=tau, delta=None - )[0]) + x = x + self.dropout( + self.self_attention(x, x, x, attn_mask=x_mask, tau=tau, delta=None)[0] + ) x = self.norm1(x) - x = x + self.dropout(self.cross_attention( - x, cross, cross, - attn_mask=cross_mask, - tau=tau, delta=delta - )[0]) + x = x + self.dropout( + self.cross_attention( + x, cross, cross, attn_mask=cross_mask, tau=tau, delta=delta + )[0] + ) y = x = self.norm2(x) y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1)))) @@ -96,21 +96,15 @@ def __init__(self, attention, d_model, d_ff=None, dropout=0.1, 
activation="relu" super(DecoderOnlyLayer, self).__init__() d_ff = d_ff or 4 * d_model self.attention = attention - self.conv1 = nn.Conv1d(in_channels=d_model, - out_channels=d_ff, kernel_size=1) - self.conv2 = nn.Conv1d( - in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) self.activation = F.relu if activation == "relu" else F.gelu def forward(self, x, attn_mask=None, tau=None, delta=None): - new_x, attn = self.attention( - x, x, x, - attn_mask=attn_mask, - tau=tau, delta=delta - ) + new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta) x = x + self.dropout(new_x) y = x = self.norm1(x) @@ -125,10 +119,8 @@ def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu" super(TimerLayer, self).__init__() d_ff = d_ff or 4 * d_model self.attention = attention - self.conv1 = nn.Conv1d(in_channels=d_model, - out_channels=d_ff, kernel_size=1) - self.conv2 = nn.Conv1d( - in_channels=d_ff, out_channels=d_model, kernel_size=1) + self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1) + self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout = nn.Dropout(dropout) @@ -136,11 +128,14 @@ def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu" def forward(self, x, n_vars, n_tokens, attn_mask=None, tau=None, delta=None): new_x, attn = self.attention( - x, x, x, + x, + x, + x, n_vars=n_vars, n_tokens=n_tokens, attn_mask=attn_mask, - tau=tau, delta=delta + tau=tau, + delta=delta, ) x = x + self.dropout(new_x) @@ -155,26 +150,27 @@ class Encoder(nn.Module): def __init__(self, attn_layers, conv_layers=None, norm_layer=None): super(Encoder, self).__init__() self.attn_layers = nn.ModuleList(attn_layers) - self.conv_layers = nn.ModuleList( - conv_layers) if conv_layers is not None else None + self.conv_layers = ( + nn.ModuleList(conv_layers) if conv_layers is not None else None + ) self.norm = norm_layer def forward(self, x, attn_mask=None, tau=None, delta=None): # x [B, L, D] attns = [] if self.conv_layers is not None: - for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)): + for i, (attn_layer, conv_layer) in enumerate( + zip(self.attn_layers, self.conv_layers) + ): delta = delta if i == 0 else None - x, attn = attn_layer( - x, attn_mask=attn_mask, tau=tau, delta=delta) + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) x = conv_layer(x) attns.append(attn) x, attn = self.attn_layers[-1](x, tau=tau, delta=None) attns.append(attn) else: for attn_layer in self.attn_layers: - x, attn = attn_layer( - x, attn_mask=attn_mask, tau=tau, delta=delta) + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) attns.append(attn) if self.norm is not None: @@ -192,8 +188,9 @@ def __init__(self, layers, norm_layer=None, projection=None): def forward(self, x, cross, x_mask=None, cross_mask=None, tau=None, delta=None): for layer in self.layers: - x = layer(x, cross, x_mask=x_mask, - cross_mask=cross_mask, tau=tau, delta=delta) + x = layer( + x, cross, x_mask=x_mask, cross_mask=cross_mask, tau=tau, delta=delta + ) if self.norm is not None: x = self.norm(x) @@ -207,26 +204,27 @@ class DecoderOnly(nn.Module): 
def __init__(self, attn_layers, conv_layers=None, norm_layer=None): super(DecoderOnly, self).__init__() self.attn_layers = nn.ModuleList(attn_layers) - self.conv_layers = nn.ModuleList( - conv_layers) if conv_layers is not None else None + self.conv_layers = ( + nn.ModuleList(conv_layers) if conv_layers is not None else None + ) self.norm = norm_layer def forward(self, x, attn_mask=None, tau=None, delta=None): # x [B, L, D] attns = [] if self.conv_layers is not None: - for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)): + for i, (attn_layer, conv_layer) in enumerate( + zip(self.attn_layers, self.conv_layers) + ): delta = delta if i == 0 else None - x, attn = attn_layer( - x, attn_mask=attn_mask, tau=tau, delta=delta) + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) x = conv_layer(x) attns.append(attn) x, attn = self.attn_layers[-1](x, tau=tau, delta=None) attns.append(attn) else: for attn_layer in self.attn_layers: - x, attn = attn_layer( - x, attn_mask=attn_mask, tau=tau, delta=delta) + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) attns.append(attn) if self.norm is not None: @@ -239,27 +237,29 @@ class TimerBlock(nn.Module): def __init__(self, attn_layers, conv_layers=None, norm_layer=None): super(TimerBlock, self).__init__() self.attn_layers = nn.ModuleList(attn_layers) - self.conv_layers = nn.ModuleList( - conv_layers) if conv_layers is not None else None + self.conv_layers = ( + nn.ModuleList(conv_layers) if conv_layers is not None else None + ) self.norm = norm_layer def forward(self, x, n_vars, n_tokens, attn_mask=None, tau=None, delta=None): # x [B, L, D] attns = [] if self.conv_layers is not None: - for i, (attn_layer, conv_layer) in enumerate(zip(self.attn_layers, self.conv_layers)): + for i, (attn_layer, conv_layer) in enumerate( + zip(self.attn_layers, self.conv_layers) + ): delta = delta if i == 0 else None - x, attn = attn_layer( - x, attn_mask=attn_mask, tau=tau, delta=delta) + x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta) x = conv_layer(x) attns.append(attn) - x, attn = self.attn_layers[-1](x, n_vars, - n_tokens, tau=tau, delta=None) + x, attn = self.attn_layers[-1](x, n_vars, n_tokens, tau=tau, delta=None) attns.append(attn) else: for attn_layer in self.attn_layers: - x, attn = attn_layer(x, n_vars, n_tokens, - attn_mask=attn_mask, tau=tau, delta=delta) + x, attn = attn_layer( + x, n_vars, n_tokens, attn_mask=attn_mask, tau=tau, delta=delta + ) attns.append(attn) if self.norm is not None: @@ -267,21 +267,22 @@ def forward(self, x, n_vars, n_tokens, attn_mask=None, tau=None, delta=None): return x, attns + class TimerMLP(nn.Module): def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.gate_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear( - self.intermediate_size, self.hidden_size, bias=False) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[hidden_act] def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + return self.down_proj( + 
self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) + ) + class TimerDecoderLayer(nn.Module): def __init__(self, config: TimerxlConfig, layer_idx: int): @@ -297,12 +298,12 @@ def __init__(self, config: TimerxlConfig, layer_idx: int): self.norm2 = torch.nn.LayerNorm(config.hidden_size) def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - use_cache: bool = False + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: bool = False, ) -> Tuple[torch.FloatTensor, Optional[Cache]]: residual = hidden_states @@ -313,7 +314,7 @@ def forward( position_ids=position_ids, past_key_value=past_key_value, ) - + hidden_states = residual + hidden_states hidden_states = self.norm1(hidden_states) @@ -325,4 +326,4 @@ def forward( if not use_cache: present_key_value = None - return hidden_states, present_key_value \ No newline at end of file + return hidden_states, present_key_value diff --git a/iotdb-core/ainode/ainode/TimerXL/layers/__init__.py b/iotdb-core/ainode/ainode/TimerXL/layers/__init__.py index 4b8ee97fad2be..2a1e720805f29 100644 --- a/iotdb-core/ainode/ainode/TimerXL/layers/__init__.py +++ b/iotdb-core/ainode/ainode/TimerXL/layers/__init__.py @@ -14,4 +14,4 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# \ No newline at end of file +# diff --git a/iotdb-core/ainode/ainode/TimerXL/models/__init__.py b/iotdb-core/ainode/ainode/TimerXL/models/__init__.py index 4b8ee97fad2be..2a1e720805f29 100644 --- a/iotdb-core/ainode/ainode/TimerXL/models/__init__.py +++ b/iotdb-core/ainode/ainode/TimerXL/models/__init__.py @@ -14,4 +14,4 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# \ No newline at end of file +# diff --git a/iotdb-core/ainode/ainode/TimerXL/models/configuration_timer.py b/iotdb-core/ainode/ainode/TimerXL/models/configuration_timer.py index 7dbce577537c5..ac5034aa85ec9 100644 --- a/iotdb-core/ainode/ainode/TimerXL/models/configuration_timer.py +++ b/iotdb-core/ainode/ainode/TimerXL/models/configuration_timer.py @@ -17,24 +17,25 @@ # from typing import List + class TimerxlConfig: model_type = "timerxl" def __init__( self, - input_token_len: int = 96, # how many points as a token, don't change - hidden_size: int = 1024, # model hidden size - intermediate_size: int = 2048, # ffn middle size - output_token_lens: List[int] = [96],# how many points as a token, don't change + input_token_len: int = 96, # number of points per token; do not change + hidden_size: int = 1024, # model hidden size + intermediate_size: int = 2048, # FFN intermediate size + output_token_lens: List[int] = [96], # supported output lengths in points; do not change num_hidden_layers: int = 8, num_attention_heads: int = 8, - hidden_act: str = "silu", # activation function - use_cache: bool = True, # kv cache - rope_theta: int = 10000, # ROBE parameter - attention_dropout: float = 0.0, - initializer_range: float = 0.02, # be of no use, because we already have weights + hidden_act: str = "silu", # activation function + use_cache: bool = True, # KV cache + rope_theta: int = 10000, # RoPE parameter + attention_dropout: float = 0.0, + initializer_range: float = 0.02, # unused, since pretrained weights are loaded max_position_embeddings: int = 10000, - ckpt_path: str = None, # weight path + ckpt_path: str = None, # weight path **kwargs, ): self.input_token_len = input_token_len @@ -54,7 +55,7 @@ def __init__( super().__init__( **kwargs, ) - + @classmethod def from_dict(cls, config_dict: dict) -> "TimerxlConfig": return cls(**config_dict) diff --git a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py index 0e66542405ebf..b3962a052a527 100644 --- a/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py +++ b/iotdb-core/ainode/ainode/TimerXL/models/timer_xl.py @@ -16,36 +16,41 @@ # under the License.
# -import torch import os -from torch import nn -from typing import Optional, List, Dict, Any, Tuple from dataclasses import dataclass -from ainode.TimerXL.layers.Transformer_EncDec import TimerDecoderLayer -from ainode.TimerXL.layers.Embed import TimerPatchEmbedding -from ainode.TimerXL.models.configuration_timer import TimerxlConfig -from ainode.core.util.masking import prepare_4d_causal_attention_mask -from ainode.core.util.huggingface_cache import Cache, DynamicCache +from typing import Any, Dict, List, Optional, Tuple -from safetensors.torch import load_file as load_safetensors +import torch from huggingface_hub import hf_hub_download +from safetensors.torch import load_file as load_safetensors +from torch import nn from ainode.core.log import Logger +from ainode.core.util.huggingface_cache import Cache, DynamicCache +from ainode.core.util.masking import prepare_4d_causal_attention_mask +from ainode.TimerXL.layers.Embed import TimerPatchEmbedding +from ainode.TimerXL.layers.Transformer_EncDec import TimerDecoderLayer +from ainode.TimerXL.models.configuration_timer import TimerxlConfig + logger = Logger() + @dataclass class Output: outputs: torch.Tensor past_key_values: Optional[Any] = None + class TimerModel(nn.Module): def __init__(self, config: TimerxlConfig): super().__init__() self.config = config self.embed_layer = TimerPatchEmbedding(config) self.layers = nn.ModuleList( - [TimerDecoderLayer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers)] + [ + TimerDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] ) self.norm = torch.nn.LayerNorm(config.hidden_size) self.gradient_checkpointing = False @@ -59,15 +64,16 @@ def forward( use_cache: bool = None, ): # input_ids is the input of time series, its shape is [batch_size, seq_len] - + if input_ids is not None: batch_size, seq_length = input_ids.shape else: raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds") + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) + + inputs_embeds = self.embed_layer(input_ids) - inputs_embeds = self.embed_layer(input_ids) - seq_length = inputs_embeds.shape[1] past_key_values_length = 0 @@ -75,15 +81,16 @@ def forward( if use_cache: use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache( - past_key_values) - past_key_values_length = past_key_values.get_usable_length( - seq_length) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, ) position_ids = position_ids.view(-1, seq_length) else: @@ -112,21 +119,22 @@ def forward( ) hidden_states = layer_outputs[0] - + if use_cache: next_decoder_cache = layer_outputs[1] - + hidden_states = self.norm(hidden_states) next_cache = None if use_cache: - next_cache = next_decoder_cache.to_legacy_cache( - ) if use_legacy_cache else next_decoder_cache + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + + return Output(outputs=hidden_states, past_key_values=next_cache) - return Output( - outputs=hidden_states, - 
past_key_values=next_cache - ) class TimerForPrediction(nn.Module): def __init__(self, config): @@ -137,11 +145,12 @@ def __init__(self, config): self.output_token_len_map = {} for i, output_token_len in enumerate(self.config.output_token_lens): lm_head_list.append( - nn.Linear(self.config.hidden_size, output_token_len, bias=False)) + nn.Linear(self.config.hidden_size, output_token_len, bias=False) + ) self.output_token_len_map[output_token_len] = i self.lm_heads = nn.ModuleList(lm_head_list) - self.loss_function = torch.nn.MSELoss(reduction='none') - + self.loss_function = torch.nn.MSELoss(reduction="none") + def forward( self, input_ids: torch.FloatTensor = None, @@ -153,9 +162,11 @@ def forward( revin: Optional[bool] = True, ): if revin: - means, stdev = input_ids.mean(dim=-1, keepdim=True), input_ids.std(dim=-1, keepdim=True) + means, stdev = input_ids.mean(dim=-1, keepdim=True), input_ids.std( + dim=-1, keepdim=True + ) input_ids = (input_ids - means) / stdev - + outputs = self.model( input_ids=input_ids, attention_mask=attention_mask, @@ -164,7 +175,7 @@ def forward( use_cache=use_cache, ) hidden_states = outputs.outputs - + if max_output_length is None: output_token_len = self.config.output_token_lens[0] max_output_length = output_token_len @@ -175,26 +186,26 @@ def forward( break else: output_token_len = h - + lm_head = self.lm_heads[self.output_token_len_map[output_token_len]] predictions = lm_head(hidden_states)[:, -1, :] - + if output_token_len > max_output_length: predictions = predictions[:, :max_output_length] if revin: predictions = predictions * stdev + means - + return Output(predictions, outputs.past_key_values) class Model(nn.Module): """ - Timer-XL: Long-Context Transformers for Unified Time Series Forecasting + Timer-XL: Long-Context Transformers for Unified Time Series Forecasting Paper: https://arxiv.org/abs/2410.04803 - + GitHub: https://github.com/thuml/Timer-XL - + Citation: @article{liu2024timer, title={Timer-XL: Long-Context Transformers for Unified Time Series Forecasting}, author={Liu, Yong and Qin, Guo and Huang, Xiangdong and Wang, Jianmin and Long, Mingsheng}, @@ -202,105 +213,116 @@ class Model(nn.Module): year={2024} } """ + def __init__(self, config: TimerxlConfig): super().__init__() - self.config = config # can't be scripted by torch - + self.config = config # can't be scripted by torch + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model = TimerForPrediction(config).to(self.device) - - if config.ckpt_path is not None and config.ckpt_path != '': - if config.ckpt_path.endswith('.pt') or config.ckpt_path.endswith('.pth'): + + if config.ckpt_path is not None and config.ckpt_path != "": + if config.ckpt_path.endswith(".pt") or config.ckpt_path.endswith(".pth"): state_dict = torch.load(config.ckpt_path) - elif config.ckpt_path.endswith('.safetensors'): + elif config.ckpt_path.endswith(".safetensors"): if not os.path.exists(config.ckpt_path): - logger.info(f"Checkpoint not found at {config.ckpt_path}, downloading from HuggingFace...") + logger.info( + f"Checkpoint not found at {config.ckpt_path}, downloading from HuggingFace..." 
+ ) repo_id = "thuml/timer-base-84m" try: - config.ckpt_path = hf_hub_download(repo_id=repo_id, filename=os.path.basename(config.ckpt_path), local_dir=os.path.dirname(config.ckpt_path)) + config.ckpt_path = hf_hub_download( + repo_id=repo_id, + filename=os.path.basename(config.ckpt_path), + local_dir=os.path.dirname(config.ckpt_path), + ) logger.info(f"Downloaded checkpoint to {config.ckpt_path}") except Exception as e: - logger.error(f"Failed to download checkpoint to {config.ckpt_path} due to {e}") + logger.error( + f"Failed to download checkpoint to {config.ckpt_path} due to {e}" + ) raise e state_dict = load_safetensors(config.ckpt_path) else: - raise ValueError('unsupported model weight type') + raise ValueError("unsupported model weight type") # If there is no key beginning with 'model.model' in state_dict, add a 'model.' before all keys. (The model code here has an additional layer of encapsulation compared to the code on huggingface.) - if not any(k.startswith('model.model') for k in state_dict.keys()): - state_dict = {'model.' + k: v for k, v in state_dict.items()} + if not any(k.startswith("model.model") for k in state_dict.keys()): + state_dict = {"model." + k: v for k, v in state_dict.items()} self.load_state_dict(state_dict, strict=True) - + def set_device(self, device): self.model.to(device) self.device = next(self.model.parameters()).device - + def inference(self, x, max_new_tokens: int = 96): # x.shape: [L, C], type: DataFrame # here we only accept C=1 temporarily # change [L, C=1] to [batchsize=1, L] self.device = next(self.model.parameters()).device - - x = torch.tensor(x, dtype=next(self.model.parameters()).dtype, device=self.device) + + x = torch.tensor( + x, dtype=next(self.model.parameters()).dtype, device=self.device + ) x = x.view(1, -1) preds = self.forward(x, max_new_tokens) preds = preds.detach().cpu().numpy() - return preds + return preds def forward(self, x, max_new_tokens: int = 96): # self.config.is_encoder_decoder = False self.eval() self.device = next(self.model.parameters()).device - + if len(x.shape) == 2: batch_size, cur_len = x.shape if cur_len < self.config.input_token_len: raise ValueError( - f"Input length must be at least {self.config.input_token_len}") + f"Input length must be at least {self.config.input_token_len}" + ) elif cur_len % self.config.input_token_len != 0: - new_len = (cur_len // self.config.input_token_len) * \ self.config.input_token_len + new_len = ( + cur_len // self.config.input_token_len + ) * self.config.input_token_len x = x[:, -new_len:] else: - raise ValueError('Input shape must be: [batch_size, seq_len]') + raise ValueError("Input shape must be: [batch_size, seq_len]") + use_cache = self.config.use_cache all_input_ids = x - + attention_mask = self.prepare_attention_mask_for_generation(all_input_ids) all_input_ids_length = all_input_ids.shape[-1] max_length = max_new_tokens + all_input_ids_length - + all_input_ids = all_input_ids.to(self.device) batch_size, cur_len = all_input_ids.shape - - unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=all_input_ids.device) + + unfinished_sequences = torch.ones( + batch_size, dtype=torch.long, device=all_input_ids.device + ) cache_position = torch.arange(cur_len, device=all_input_ids.device) true_seq_len = cur_len // self.config.input_token_len attention_mask = attention_mask[:, -true_seq_len:] - + this_peer_finished = False past_key_values = None position_ids = None while not this_peer_finished: ( input_ids, position_ids, past_key_values, attention_mask, -
revin - ) = self.prepare_inputs_for_generation( - all_input_ids, - past_key_values=past_key_values, - attention_mask=attention_mask, - # position_ids=position_ids # Wrong?! - position_ids=None # True?! based on huggingface code + (input_ids, position_ids, past_key_values, attention_mask, revin) = ( + self.prepare_inputs_for_generation( + all_input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + # position_ids=position_ids # Wrong?! + position_ids=None, # True?! based on huggingface code + ) ) input_length = all_input_ids.shape[1] - - # forward pass to get next token + + # forward pass to get next token outputs = self.model( input_ids, attention_mask=attention_mask, @@ -308,42 +330,47 @@ def forward(self, x, max_new_tokens: int = 96): past_key_values=past_key_values, use_cache=use_cache, max_output_length=max_length - input_length, - revin=revin + revin=revin, ) - + next_tokens = outputs.outputs - + # update generated ids, model inputs, and length for next step horizon_length = next_tokens.shape[1] // self.config.input_token_len - + all_input_ids = torch.cat([all_input_ids, next_tokens], dim=-1) - ( - past_key_values, - attention_mask, - cache_position - ) = self._update_model_kwargs_for_generation( - outputs, - attention_mask=attention_mask, - horizon_length=horizon_length, - cache_position=cache_position, + (past_key_values, attention_mask, cache_position) = ( + self._update_model_kwargs_for_generation( + outputs, + attention_mask=attention_mask, + horizon_length=horizon_length, + cache_position=cache_position, + ) + ) + + unfinished_sequences = unfinished_sequences & ( + all_input_ids.shape[1] < max_length ) - - unfinished_sequences = unfinished_sequences & (all_input_ids.shape[1] < max_length) this_peer_finished = unfinished_sequences.max() == 0 - + if all_input_ids.shape[1] > max_length: all_input_ids = all_input_ids[:, :max_length] - - return all_input_ids[:, -(max_length - cur_len):] - + + return all_input_ids[:, -(max_length - cur_len) :] + def prepare_attention_mask_for_generation( self, inputs: torch.Tensor, ) -> torch.LongTensor: return torch.ones(inputs.shape[:2], dtype=torch.long, device=inputs.device) - + def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, revin=True, position_ids=None + self, + input_ids, + past_key_values=None, + attention_mask=None, + revin=True, + position_ids=None, ): # Omit tokens covered by past_key_values if past_key_values is not None: @@ -363,21 +390,22 @@ def prepare_inputs_for_generation( # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) - if attention_mask is not None and attention_mask.shape[1] > (input_ids.shape[1] // self.config.input_token_len): - input_ids = input_ids[:, - - (attention_mask.shape[1] - past_length):] + if attention_mask is not None and attention_mask.shape[1] > ( + input_ids.shape[1] // self.config.input_token_len + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. 
elif past_length < (input_ids.shape[1] // self.config.input_token_len): - input_ids = input_ids[:, past_length * - self.config.input_token_len:] + input_ids = input_ids[:, past_length * self.config.input_token_len :] # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + (input_ids.shape[1] // self.config.input_token_len) > max_cache_length + max_cache_length is not None + and attention_mask is not None + and cache_length + (input_ids.shape[1] // self.config.input_token_len) + > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] @@ -386,17 +414,18 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, - - (input_ids.shape[1] // self.config.input_token_len):] + position_ids = position_ids[ + :, -(input_ids.shape[1] // self.config.input_token_len) : + ] return (input_ids, position_ids, past_key_values, attention_mask, revin) - + def _update_model_kwargs_for_generation( - self, - outputs, - attention_mask = None, - cache_position = None, - horizon_length: int = 1, + self, + outputs, + attention_mask=None, + cache_position=None, + horizon_length: int = 1, ) -> Dict[str, Any]: # update past_key_values past_key_values = outputs.past_key_values @@ -404,10 +433,14 @@ def _update_model_kwargs_for_generation( # update attention mask if attention_mask is not None: attention_mask = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], horizon_length))], dim=-1 + [ + attention_mask, + attention_mask.new_ones((attention_mask.shape[0], horizon_length)), + ], + dim=-1, ) if cache_position is not None: cache_position = cache_position[-1:] + horizon_length - return (past_key_values, attention_mask, cache_position) \ No newline at end of file + return (past_key_values, attention_mask, cache_position) diff --git a/iotdb-core/ainode/ainode/core/client.py b/iotdb-core/ainode/ainode/core/client.py index 2796d46ab930d..15385928a84a6 100644 --- a/iotdb-core/ainode/ainode/core/client.py +++ b/iotdb-core/ainode/ainode/core/client.py @@ -17,8 +17,8 @@ # import time +from thrift.protocol import TBinaryProtocol, TCompactProtocol from thrift.Thrift import TException -from thrift.protocol import TCompactProtocol, TBinaryProtocol from thrift.transport import TSocket, TTransport from ainode.core.config import AINodeDescriptor @@ -26,11 +26,20 @@ from ainode.core.log import Logger from ainode.core.util.decorator import singleton from ainode.core.util.status import verify_success -from ainode.thrift.common.ttypes import TEndPoint, TSStatus, TAINodeLocation, TAINodeConfiguration +from ainode.thrift.common.ttypes import ( + TAINodeConfiguration, + TAINodeLocation, + TEndPoint, + TSStatus, +) from ainode.thrift.confignode import IConfigNodeRPCService -from ainode.thrift.confignode.ttypes import (TAINodeRemoveReq, TNodeVersionInfo, - TAINodeRegisterReq, TAINodeRestartReq) -from ainode.thrift.confignode.ttypes import TUpdateModelInfoReq +from ainode.thrift.confignode.ttypes import ( + TAINodeRegisterReq, + TAINodeRemoveReq, + TAINodeRestartReq, + TNodeVersionInfo, + TUpdateModelInfoReq, +) logger = Logger() @@ -38,7 +47,9 @@ @singleton class ClientManager(object): def 
__init__(self): - self._config_node_endpoint = AINodeDescriptor().get_config().get_ain_target_config_node_list() + self._config_node_endpoint = ( + AINodeDescriptor().get_config().get_ain_target_config_node_list() + ) def borrow_config_node_client(self): return ConfigNodeClient(config_leader=self._config_node_endpoint) @@ -52,7 +63,9 @@ def __init__(self, config_leader: TEndPoint): self._transport = None self._client = None - self._MSG_RECONNECTION_FAIL = "Fail to connect to any config node. Please check status of ConfigNodes" + self._MSG_RECONNECTION_FAIL = ( + "Fail to connect to any config node. Please check status of ConfigNodes" + ) self._RETRY_NUM = 5 self._RETRY_INTERVAL_MS = 1 @@ -64,7 +77,10 @@ def _try_to_connect(self) -> None: self._connect(self._config_leader) return except TException: - logger.warning("The current node {} may have been down, try next node", self._config_leader) + logger.warning( + "The current node {} may have been down, try next node", + self._config_leader, + ) self._config_leader = None if self._transport is not None: @@ -79,7 +95,10 @@ def _try_to_connect(self) -> None: self._connect(try_endpoint) return except TException: - logger.warning("The current node {} may have been down, try next node", try_endpoint) + logger.warning( + "The current node {} may have been down, try next node", + try_endpoint, + ) try_host_num = try_host_num + 1 @@ -126,65 +145,83 @@ def _update_config_node_leader(self, status: TSStatus) -> bool: return True return False - def node_register(self, cluster_name: str, configuration: TAINodeConfiguration, - version_info: TNodeVersionInfo) -> int: + def node_register( + self, + cluster_name: str, + configuration: TAINodeConfiguration, + version_info: TNodeVersionInfo, + ) -> int: req = TAINodeRegisterReq( clusterName=cluster_name, aiNodeConfiguration=configuration, - versionInfo=version_info + versionInfo=version_info, ) for _ in range(0, self._RETRY_NUM): try: resp = self._client.registerAINode(req) if not self._update_config_node_leader(resp.status): - verify_success(resp.status, "An error occurs when calling node_register()") + verify_success( + resp.status, "An error occurs when calling node_register()" + ) self._config_nodes = resp.configNodeList return resp.aiNodeId except TTransport.TException: - logger.warning("Failed to connect to ConfigNode {} from AINode when executing node_register()", - self._config_leader) + logger.warning( + "Failed to connect to ConfigNode {} from AINode when executing node_register()", + self._config_leader, + ) self._config_leader = None self._wait_and_reconnect() raise TException(self._MSG_RECONNECTION_FAIL) - def node_restart(self, cluster_name: str, configuration: TAINodeConfiguration, - version_info: TNodeVersionInfo) -> None: + def node_restart( + self, + cluster_name: str, + configuration: TAINodeConfiguration, + version_info: TNodeVersionInfo, + ) -> None: req = TAINodeRestartReq( clusterName=cluster_name, aiNodeConfiguration=configuration, - versionInfo=version_info + versionInfo=version_info, ) for _ in range(0, self._RETRY_NUM): try: resp = self._client.restartAINode(req) if not self._update_config_node_leader(resp.status): - verify_success(resp.status, "An error occurs when calling node_restart()") + verify_success( + resp.status, "An error occurs when calling node_restart()" + ) self._config_nodes = resp.configNodeList return resp.status except TTransport.TException: - logger.warning("Failed to connect to ConfigNode {} from AINode when executing node_restart()", - self._config_leader) + 
logger.warning( + "Failed to connect to ConfigNode {} from AINode when executing node_restart()", + self._config_leader, + ) self._config_leader = None self._wait_and_reconnect() raise TException(self._MSG_RECONNECTION_FAIL) def node_remove(self, location: TAINodeLocation): - req = TAINodeRemoveReq( - aiNodeLocation=location - ) + req = TAINodeRemoveReq(aiNodeLocation=location) for _ in range(0, self._RETRY_NUM): try: status = self._client.removeAINode(req) if not self._update_config_node_leader(status): - verify_success(status, "An error occurs when calling node_restart()") + verify_success( + status, "An error occurs when calling node_remove()" + ) return status except TTransport.TException: - logger.warning("Failed to connect to ConfigNode {} from AINode when executing node_remove()", - self._config_leader) + logger.warning( + "Failed to connect to ConfigNode {} from AINode when executing node_remove()", + self._config_leader, + ) self._config_leader = None self._wait_and_reconnect() raise TException(self._MSG_RECONNECTION_FAIL) @@ -194,35 +231,50 @@ def get_ainode_configuration(self, node_id: int) -> map: try: resp = self._client.getAINodeConfiguration(node_id) if not self._update_config_node_leader(resp.status): - verify_success(resp.status, "An error occurs when calling get_ainode_configuration()") + verify_success( + resp.status, + "An error occurs when calling get_ainode_configuration()", + ) return resp.aiNodeConfigurationMap except TTransport.TException: - logger.warning("Failed to connect to ConfigNode {} from AINode when executing " - "get_ainode_configuration()", - self._config_leader) + logger.warning( + "Failed to connect to ConfigNode {} from AINode when executing " + "get_ainode_configuration()", + self._config_leader, + ) self._config_leader = None self._wait_and_reconnect() raise TException(self._MSG_RECONNECTION_FAIL) - def update_model_info(self, model_id:str, model_status:int, attribute:str = "", ainode_id=None, input_length=0, output_length=0) -> None: + def update_model_info( + self, + model_id: str, + model_status: int, + attribute: str = "", + ainode_id=None, + input_length=0, + output_length=0, + ) -> None: if ainode_id is None: ainode_id = [] for _ in range(0, self._RETRY_NUM): try: - req = TUpdateModelInfoReq( - model_id, model_status, attribute - ) + req = TUpdateModelInfoReq(model_id, model_status, attribute) if ainode_id is not None: req.aiNodeIds = ainode_id req.inputLength = input_length req.outputLength = output_length status = self._client.updateModelInfo(req) if not self._update_config_node_leader(status): - verify_success(status, "An error occurs when calling update model info") + verify_success( + status, "An error occurs when calling update model info" + ) return status except TTransport.TException: - logger.warning("Failed to connect to ConfigNode {} from AINode when executing update model info", - self._config_leader) + logger.warning( + "Failed to connect to ConfigNode {} from AINode when executing update model info", + self._config_leader, + ) self._config_leader = None self._wait_and_reconnect() raise TException(self._MSG_RECONNECTION_FAIL) diff --git a/iotdb-core/ainode/ainode/core/config.py b/iotdb-core/ainode/ainode/core/config.py index 036dc20b9c224..62de76fcbb839 100644 --- a/iotdb-core/ainode/ainode/core/config.py +++ b/iotdb-core/ainode/ainode/core/config.py @@ -17,13 +17,24 @@ # import os -from ainode.core.constant import (AINODE_CONF_DIRECTORY_NAME, - AINODE_CONF_FILE_NAME, - AINODE_MODELS_DIR, AINODE_LOG_DIR, AINODE_SYSTEM_DIR, 
AINODE_INFERENCE_RPC_ADDRESS, - AINODE_INFERENCE_RPC_PORT, AINODE_THRIFT_COMPRESSION_ENABLED, - AINODE_SYSTEM_FILE_NAME, AINODE_CLUSTER_NAME, AINODE_VERSION_INFO, AINODE_BUILD_INFO, - AINODE_CONF_GIT_FILE_NAME, AINODE_CONF_POM_FILE_NAME, AINODE_ROOT_DIR, - AINODE_ROOT_CONF_DIRECTORY_NAME) +from ainode.core.constant import ( + AINODE_BUILD_INFO, + AINODE_CLUSTER_NAME, + AINODE_CONF_DIRECTORY_NAME, + AINODE_CONF_FILE_NAME, + AINODE_CONF_GIT_FILE_NAME, + AINODE_CONF_POM_FILE_NAME, + AINODE_INFERENCE_RPC_ADDRESS, + AINODE_INFERENCE_RPC_PORT, + AINODE_LOG_DIR, + AINODE_MODELS_DIR, + AINODE_ROOT_CONF_DIRECTORY_NAME, + AINODE_ROOT_DIR, + AINODE_SYSTEM_DIR, + AINODE_SYSTEM_FILE_NAME, + AINODE_THRIFT_COMPRESSION_ENABLED, + AINODE_VERSION_INFO, +) from ainode.core.exception import BadNodeUrlError from ainode.core.log import Logger from ainode.core.util.decorator import singleton @@ -119,7 +130,9 @@ def set_ain_system_dir(self, ain_system_dir: str) -> None: def get_ain_thrift_compression_enabled(self) -> bool: return self._ain_thrift_compression_enabled - def set_ain_thrift_compression_enabled(self, ain_thrift_compression_enabled: int) -> None: + def set_ain_thrift_compression_enabled( + self, ain_thrift_compression_enabled: int + ) -> None: self._ain_thrift_compression_enabled = ain_thrift_compression_enabled def get_ain_model_storage_cache_size(self) -> int: @@ -129,7 +142,9 @@ def get_ain_target_config_node_list(self) -> TEndPoint: return self._ain_target_config_node_list def set_ain_target_config_node_list(self, ain_target_config_node_list: str) -> None: - self._ain_target_config_node_list = parse_endpoint_url(ain_target_config_node_list) + self._ain_target_config_node_list = parse_endpoint_url( + ain_target_config_node_list + ) @singleton @@ -141,31 +156,41 @@ def __init__(self): logger.info("AINodeDescriptor is init successfully.") def _load_config_from_file(self) -> None: - system_properties_file = os.path.join(self._config.get_ain_system_dir(), AINODE_SYSTEM_FILE_NAME) + system_properties_file = os.path.join( + self._config.get_ain_system_dir(), AINODE_SYSTEM_FILE_NAME + ) if os.path.exists(system_properties_file): system_configs = load_properties(system_properties_file) - if 'ainode_id' in system_configs: - self._config.set_ainode_id(int(system_configs['ainode_id'])) + if "ainode_id" in system_configs: + self._config.set_ainode_id(int(system_configs["ainode_id"])) - git_file = os.path.join(AINODE_ROOT_DIR, AINODE_ROOT_CONF_DIRECTORY_NAME, AINODE_CONF_GIT_FILE_NAME) + git_file = os.path.join( + AINODE_ROOT_DIR, AINODE_ROOT_CONF_DIRECTORY_NAME, AINODE_CONF_GIT_FILE_NAME + ) if os.path.exists(git_file): git_configs = load_properties(git_file) - if 'git.commit.id.abbrev' in git_configs: - build_info = git_configs['git.commit.id.abbrev'] - if 'git.dirty' in git_configs: - if git_configs['git.dirty'] == "true": + if "git.commit.id.abbrev" in git_configs: + build_info = git_configs["git.commit.id.abbrev"] + if "git.dirty" in git_configs: + if git_configs["git.dirty"] == "true": build_info += "-dev" self._config.set_build_info(build_info) - pom_file = os.path.join(AINODE_ROOT_DIR, AINODE_ROOT_CONF_DIRECTORY_NAME, AINODE_CONF_POM_FILE_NAME) + pom_file = os.path.join( + AINODE_ROOT_DIR, AINODE_ROOT_CONF_DIRECTORY_NAME, AINODE_CONF_POM_FILE_NAME + ) if os.path.exists(pom_file): pom_configs = load_properties(pom_file) - if 'version' in pom_configs: - self._config.set_version_info(pom_configs['version']) + if "version" in pom_configs: + self._config.set_version_info(pom_configs["version"]) conf_file = 
os.path.join(AINODE_CONF_DIRECTORY_NAME, AINODE_CONF_FILE_NAME) if not os.path.exists(conf_file): - logger.info("Cannot find AINode config file '{}', use default configuration.".format(conf_file)) + logger.info( + "Cannot find AINode config file '{}', use default configuration.".format( + conf_file + ) + ) return # noinspection PyBroadException @@ -174,43 +199,57 @@ def _load_config_from_file(self) -> None: config_keys = file_configs.keys() - if 'ain_inference_rpc_address' in config_keys: - self._config.set_ain_inference_rpc_address(file_configs['ain_inference_rpc_address']) + if "ain_inference_rpc_address" in config_keys: + self._config.set_ain_inference_rpc_address( + file_configs["ain_inference_rpc_address"] + ) - if 'ain_inference_rpc_port' in config_keys: - self._config.set_ain_inference_rpc_port(int(file_configs['ain_inference_rpc_port'])) + if "ain_inference_rpc_port" in config_keys: + self._config.set_ain_inference_rpc_port( + int(file_configs["ain_inference_rpc_port"]) + ) - if 'ain_models_dir' in config_keys: - self._config.set_ain_models_dir(file_configs['ain_models_dir']) + if "ain_models_dir" in config_keys: + self._config.set_ain_models_dir(file_configs["ain_models_dir"]) - if 'ain_system_dir' in config_keys: - self._config.set_ain_system_dir(file_configs['ain_system_dir']) + if "ain_system_dir" in config_keys: + self._config.set_ain_system_dir(file_configs["ain_system_dir"]) - if 'ain_seed_config_node' in config_keys: - self._config.set_ain_target_config_node_list(file_configs['ain_seed_config_node']) + if "ain_seed_config_node" in config_keys: + self._config.set_ain_target_config_node_list( + file_configs["ain_seed_config_node"] + ) - if 'cluster_name' in config_keys: - self._config.set_cluster_name(file_configs['cluster_name']) + if "cluster_name" in config_keys: + self._config.set_cluster_name(file_configs["cluster_name"]) - if 'ain_thrift_compression_enabled' in config_keys: - self._config.set_ain_thrift_compression_enabled(int(file_configs['ain_thrift_compression_enabled'])) + if "ain_thrift_compression_enabled" in config_keys: + self._config.set_ain_thrift_compression_enabled( + int(file_configs["ain_thrift_compression_enabled"]) + ) - if 'ain_logs_dir' in config_keys: - log_dir = file_configs['ain_logs_dir'] + if "ain_logs_dir" in config_keys: + log_dir = file_configs["ain_logs_dir"] self._config.set_ain_logs_dir(log_dir) - Logger(log_dir=log_dir).info(f"Successfully load config from {conf_file}.") + Logger(log_dir=log_dir).info( + f"Successfully load config from {conf_file}." + ) except BadNodeUrlError: logger.warning("Cannot load AINode conf file, use default configuration.") except Exception as e: - logger.warning("Cannot load AINode conf file caused by: {}, use default configuration. ".format(e)) + logger.warning( + "Cannot load AINode conf file caused by: {}, use default configuration. ".format( + e + ) + ) def get_config(self) -> AINodeConfig: return self._config -def load_properties(filepath, sep='=', comment_char='#'): +def load_properties(filepath, sep="=", comment_char="#"): """ Read the file passed as parameter as a properties file. """ @@ -227,7 +266,7 @@ def load_properties(filepath, sep='=', comment_char='#'): def parse_endpoint_url(endpoint_url: str) -> TEndPoint: - """ Parse TEndPoint from a given endpoint url. + """Parse TEndPoint from a given endpoint url. 
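# The config loader above combines two small utilities: load_properties reads
# plain `key=value` lines, and parse_endpoint_url turns `ip:port` into a
# TEndPoint. A dependency-free sketch of both behaviors (a tuple stands in for
# TEndPoint here):
def parse_properties_text(text, sep="=", comment_char="#"):
    props = {}
    for line in text.splitlines():
        line = line.strip()
        if line and not line.startswith(comment_char) and sep in line:
            key, value = line.split(sep, 1)
            props[key.strip()] = value.strip()
    return props

def parse_endpoint(url):
    ip, port = url.rsplit(":", 1)  # expected format: ip:port
    return ip, int(port)

conf = parse_properties_text("ain_inference_rpc_port=10810\nain_seed_config_node=127.0.0.1:10710")
assert parse_endpoint(conf["ain_seed_config_node"]) == ("127.0.0.1", 10710)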
Args: endpoint_url: an endpoint url, format: ip:port Returns: diff --git a/iotdb-core/ainode/ainode/core/constant.py b/iotdb-core/ainode/ainode/core/constant.py index 1078836ff6520..24d13a12ab8a9 100644 --- a/iotdb-core/ainode/ainode/core/constant.py +++ b/iotdb-core/ainode/ainode/core/constant.py @@ -38,18 +38,18 @@ AINODE_CLUSTER_NAME = "defaultCluster" AINODE_VERSION_INFO = "UNKNOWN" AINODE_BUILD_INFO = "UNKNOWN" -AINODE_ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))) +AINODE_ROOT_DIR = os.path.dirname( + os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) +) # AINode log -AINODE_LOG_FILE_NAMES = ['log_ainode_all.log', - 'log_ainode_info.log', - 'log_ainode_warning.log', - 'log_ainode_error.log'] -AINODE_LOG_FILE_LEVELS = [ - logging.DEBUG, - logging.INFO, - logging.WARNING, - logging.ERROR] +AINODE_LOG_FILE_NAMES = [ + "log_ainode_all.log", + "log_ainode_info.log", + "log_ainode_warning.log", + "log_ainode_error.log", +] +AINODE_LOG_FILE_LEVELS = [logging.DEBUG, logging.INFO, logging.WARNING, logging.ERROR] TRIAL_ID_PREFIX = "__trial_" DEFAULT_TRIAL_ID = TRIAL_ID_PREFIX + "0" @@ -148,7 +148,7 @@ class BuiltInModelType(Enum): GAUSSIAN_HMM = "_gaussianhmm" GMM_HMM = "_gmmhmm" STRAY = "_stray" - + # timerxl TIMER_XL = "_timerxl" @@ -168,25 +168,25 @@ class AttributeName(Enum): PREDICT_LENGTH = "predict_length" # NaiveForecaster - STRATEGY = 'strategy' - SP = 'sp' + STRATEGY = "strategy" + SP = "sp" # STLForecaster # SP = 'sp' - SEASONAL = 'seasonal' - SEASONAL_DEG = 'seasonal_deg' - TREND_DEG = 'trend_deg' - LOW_PASS_DEG = 'low_pass_deg' - SEASONAL_JUMP = 'seasonal_jump' - TREND_JUMP = 'trend_jump' - LOSS_PASS_JUMP = 'low_pass_jump' + SEASONAL = "seasonal" + SEASONAL_DEG = "seasonal_deg" + TREND_DEG = "trend_deg" + LOW_PASS_DEG = "low_pass_deg" + SEASONAL_JUMP = "seasonal_jump" + TREND_JUMP = "trend_jump" + LOSS_PASS_JUMP = "low_pass_jump" # ExponentialSmoothing - DAMPED_TREND = 'damped_trend' - INITIALIZATION_METHOD = 'initialization_method' - OPTIMIZED = 'optimized' - REMOVE_BIAS = 'remove_bias' - USE_BRUTE = 'use_brute' + DAMPED_TREND = "damped_trend" + INITIALIZATION_METHOD = "initialization_method" + OPTIMIZED = "optimized" + REMOVE_BIAS = "remove_bias" + USE_BRUTE = "use_brute" # Arima ORDER = "order" @@ -248,7 +248,7 @@ class AttributeName(Enum): P = "p" SIZE_THRESHOLD = "size_threshold" OUTLIER_TAIL = "outlier_tail" - + # timerxl INPUT_TOKEN_LEN = "input_token_len" HIDDEN_SIZE = "hidden_size" diff --git a/iotdb-core/ainode/ainode/core/exception.py b/iotdb-core/ainode/ainode/core/exception.py index a9b8c496d65aa..977b10cfa04fa 100644 --- a/iotdb-core/ainode/ainode/core/exception.py +++ b/iotdb-core/ainode/ainode/core/exception.py @@ -17,7 +17,7 @@ # import re -from ainode.core.constant import DEFAULT_MODEL_FILE_NAME, DEFAULT_CONFIG_FILE_NAME +from ainode.core.constant import DEFAULT_CONFIG_FILE_NAME, DEFAULT_MODEL_FILE_NAME class _BaseError(Exception): @@ -41,8 +41,10 @@ def __init__(self, file_path: str): class BadConfigValueError(_BaseError): - def __init__(self, config_name: str, config_value, hint: str = ''): - self.message = "Bad value [{0}] for config {1}. {2}".format(config_value, config_name, hint) + def __init__(self, config_name: str, config_value, hint: str = ""): + self.message = "Bad value [{0}] for config {1}. 
{2}".format( + config_value, config_name, hint + ) class MissingConfigError(_BaseError): @@ -62,7 +64,9 @@ def __init__(self, option_name: str): class WrongTypeConfigError(_BaseError): def __init__(self, config_name: str, expected_type: str): - self.message = "Wrong type for config: {0}, expected: {1}".format(config_name, expected_type) + self.message = "Wrong type for config: {0}, expected: {1}".format( + config_name, expected_type + ) class UnsupportedError(_BaseError): @@ -72,16 +76,13 @@ def __init__(self, msg: str): class InvalidUriError(_BaseError): def __init__(self, uri: str): - self.message = "Invalid uri: {}, there are no {} or {} under this uri.".format(uri, DEFAULT_MODEL_FILE_NAME, - DEFAULT_CONFIG_FILE_NAME) + self.message = "Invalid uri: {}, there are no {} or {} under this uri.".format( + uri, DEFAULT_MODEL_FILE_NAME, DEFAULT_CONFIG_FILE_NAME + ) class InvalidWindowArgumentError(_BaseError): - def __init__( - self, - window_interval, - window_step, - dataset_length): + def __init__(self, window_interval, window_step, dataset_length): self.message = f"Invalid inference input: window_interval {window_interval}, window_step {window_step}, dataset_length {dataset_length}" @@ -97,30 +98,41 @@ def __init__(self, msg: str): class WrongAttributeTypeError(_BaseError): def __init__(self, attribute_name: str, expected_type: str): - self.message = "Wrong type for attribute: {0}, expected: {1}".format(attribute_name, expected_type) + self.message = "Wrong type for attribute: {0}, expected: {1}".format( + attribute_name, expected_type + ) class NumericalRangeException(_BaseError): def __init__(self, attribute_name: str, value, min_value, max_value): - self.message = "Attribute {0} expect value between {1} and {2}, got {3} instead." \ - .format(attribute_name, min_value, max_value, value) + self.message = ( + "Attribute {0} expect value between {1} and {2}, got {3} instead.".format( + attribute_name, min_value, max_value, value + ) + ) class StringRangeException(_BaseError): def __init__(self, attribute_name: str, value: str, expect_value): - self.message = "Attribute {0} expect value in {1}, got {2} instead." \ - .format(attribute_name, expect_value, value) + self.message = "Attribute {0} expect value in {1}, got {2} instead.".format( + attribute_name, expect_value, value + ) class ListRangeException(_BaseError): def __init__(self, attribute_name: str, value: list, expected_type: str): - self.message = "Attribute {0} expect value type list[{1}], got {2} instead." 
\ - .format(attribute_name, expected_type, value) + self.message = ( + "Attribute {0} expect value type list[{1}], got {2} instead.".format( + attribute_name, expected_type, value + ) + ) class AttributeNotSupportError(_BaseError): def __init__(self, model_name: str, attribute_name: str): - self.message = "Attribute {0} is not supported in model {1}".format(attribute_name, model_name) + self.message = "Attribute {0} is not supported in model {1}".format( + attribute_name, model_name + ) # This is used to extract the key message in RuntimeError instead of the traceback message diff --git a/iotdb-core/ainode/ainode/core/handler.py b/iotdb-core/ainode/ainode/core/handler.py index fc8f8f1aae74f..456bc97269a3d 100644 --- a/iotdb-core/ainode/ainode/core/handler.py +++ b/iotdb-core/ainode/ainode/core/handler.py @@ -20,9 +20,17 @@ from ainode.core.manager.inference_manager import InferenceManager from ainode.core.manager.model_manager import ModelManager from ainode.thrift.ainode import IAINodeRPCService -from ainode.thrift.ainode.ttypes import (TDeleteModelReq, TRegisterModelReq, - TAIHeartbeatReq, TInferenceReq, TRegisterModelResp, TInferenceResp, - TAIHeartbeatResp, TTrainingReq, TForecastReq) +from ainode.thrift.ainode.ttypes import ( + TAIHeartbeatReq, + TAIHeartbeatResp, + TDeleteModelReq, + TForecastReq, + TInferenceReq, + TInferenceResp, + TRegisterModelReq, + TRegisterModelResp, + TTrainingReq, +) from ainode.thrift.common.ttypes import TSStatus diff --git a/iotdb-core/ainode/ainode/core/log.py b/iotdb-core/ainode/ainode/core/log.py index 4b2f412eaafb2..f1fd470a235d0 100644 --- a/iotdb-core/ainode/ainode/core/log.py +++ b/iotdb-core/ainode/ainode/core/log.py @@ -23,7 +23,11 @@ import sys import threading -from ainode.core.constant import STD_LEVEL, AINODE_LOG_FILE_NAMES, AINODE_LOG_FILE_LEVELS +from ainode.core.constant import ( + AINODE_LOG_FILE_LEVELS, + AINODE_LOG_FILE_NAMES, + STD_LEVEL, +) from ainode.core.util.decorator import singleton @@ -46,7 +50,9 @@ def custom_log_info(): # if file_name is not in current working directory, find the first "iotdb" in the path for l in range(len(file_name)): i = len(file_name) - l - 1 - if file_name[i:].startswith("iotdb/") or file_name[i:].startswith("iotdb\\"): + if file_name[i:].startswith("iotdb/") or file_name[i:].startswith( + "iotdb\\" + ): file_name = file_name[i:] break @@ -57,7 +63,7 @@ def custom_log_info(): @singleton class Logger: - """ Logger is a singleton, it will be initialized when AINodeDescriptor is inited for the first time. + """Logger is a singleton, it will be initialized when AINodeDescriptor is inited for the first time. You can just use Logger() to get it anywhere. 
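# The exception classes above all pre-render a human-readable .message in
# __init__, so call sites can treat every AINode error uniformly. A simplified,
# self-contained illustration of that convention (not the real class bodies):
class _BaseErrorSketch(Exception):
    def __init__(self, message=""):
        super().__init__(message)
        self.message = message

class NumericalRangeSketch(_BaseErrorSketch):
    def __init__(self, attribute_name, value, min_value, max_value):
        super().__init__(
            "Attribute {0} expect value between {1} and {2}, got {3} instead.".format(
                attribute_name, min_value, max_value, value
            )
        )

try:
    raise NumericalRangeSketch("predict_length", 0, 1, 5000)
except _BaseErrorSketch as e:
    print(e.message)  # one uniform field, whatever the concrete error type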
Args: @@ -72,9 +78,9 @@ class Logger: def __init__(self, log_dir=None): - self.logger_format = logging.Formatter(fmt='%(asctime)s %(levelname)s %(' - 'message)s', - datefmt='%Y-%m-%d %H:%M:%S') + self.logger_format = logging.Formatter( + fmt="%(asctime)s %(levelname)s %(" "message)s", datefmt="%Y-%m-%d %H:%M:%S" + ) self.logger = logging.getLogger(str(random.random())) self.logger.handlers.clear() @@ -94,12 +100,14 @@ def __init__(self, log_dir=None): for file_name in file_names: log_path = log_dir + "/" + file_name if not os.path.exists(log_path): - f = open(log_path, mode='w', encoding='utf-8') + f = open(log_path, mode="w", encoding="utf-8") f.close() os.chmod(log_path, 0o777) self.file_handlers = [] for l in range(len(file_names)): - self.file_handlers.append(logging.FileHandler(log_dir + "/" + file_names[l], mode='a')) + self.file_handlers.append( + logging.FileHandler(log_dir + "/" + file_names[l], mode="a") + ) self.file_handlers[l].setLevel(file_levels[l]) self.file_handlers[l].setFormatter(self.logger_format) @@ -114,20 +122,20 @@ def __init__(self, log_dir=None): def debug(self, *args) -> None: self._lock.acquire() - self.logger.debug(' '.join(map(str, args))) + self.logger.debug(" ".join(map(str, args))) self._lock.release() def info(self, *args) -> None: self._lock.acquire() - self.logger.info(' '.join(map(str, args))) + self.logger.info(" ".join(map(str, args))) self._lock.release() def warning(self, *args) -> None: self._lock.acquire() - self.logger.warning(' '.join(map(str, args))) + self.logger.warning(" ".join(map(str, args))) self._lock.release() def error(self, *args) -> None: self._lock.acquire() - self.logger.error(' '.join(map(str, args))) + self.logger.error(" ".join(map(str, args))) self._lock.release() diff --git a/iotdb-core/ainode/ainode/core/manager/cluster_manager.py b/iotdb-core/ainode/ainode/core/manager/cluster_manager.py index da7008b776259..b7d65f47dc19d 100644 --- a/iotdb-core/ainode/ainode/core/manager/cluster_manager.py +++ b/iotdb-core/ainode/ainode/core/manager/cluster_manager.py @@ -17,7 +17,7 @@ # import psutil -from ainode.thrift.ainode.ttypes import TAIHeartbeatResp, TAIHeartbeatReq +from ainode.thrift.ainode.ttypes import TAIHeartbeatReq, TAIHeartbeatResp from ainode.thrift.common.ttypes import TLoadSample @@ -27,15 +27,20 @@ def get_heart_beat(req: TAIHeartbeatReq) -> TAIHeartbeatResp: if req.needSamplingLoad: cpu_percent = psutil.cpu_percent(interval=1) memory_percent = psutil.virtual_memory().percent - disk_usage = psutil.disk_usage('/') + disk_usage = psutil.disk_usage("/") disk_free = disk_usage.free - load_sample = TLoadSample(cpuUsageRate=cpu_percent, - memoryUsageRate=memory_percent, - diskUsageRate=disk_usage.percent, - freeDiskSpace=disk_free / 1024 / 1024 / 1024) - return TAIHeartbeatResp(heartbeatTimestamp=req.heartbeatTimestamp, - status="Running", - loadSample=load_sample) + load_sample = TLoadSample( + cpuUsageRate=cpu_percent, + memoryUsageRate=memory_percent, + diskUsageRate=disk_usage.percent, + freeDiskSpace=disk_free / 1024 / 1024 / 1024, + ) + return TAIHeartbeatResp( + heartbeatTimestamp=req.heartbeatTimestamp, + status="Running", + loadSample=load_sample, + ) else: - return TAIHeartbeatResp(heartbeatTimestamp=req.heartbeatTimestamp, - status="Running") + return TAIHeartbeatResp( + heartbeatTimestamp=req.heartbeatTimestamp, status="Running" + ) diff --git a/iotdb-core/ainode/ainode/core/manager/inference_manager.py b/iotdb-core/ainode/ainode/core/manager/inference_manager.py index 9df60fe1a400b..175961e731344 100644 --- 
a/iotdb-core/ainode/ainode/core/manager/inference_manager.py +++ b/iotdb-core/ainode/ainode/core/manager/inference_manager.py @@ -15,19 +15,28 @@ # specific language governing permissions and limitations # under the License. # -from abc import abstractmethod, ABC +from abc import ABC, abstractmethod import pandas as pd import torch from iotdb.tsfile.utils.tsblock_serde import deserialize from ainode.core.constant import TSStatusCode -from ainode.core.exception import InvalidWindowArgumentError, InferenceModelInternalError, runtime_error_extractor +from ainode.core.exception import ( + InferenceModelInternalError, + InvalidWindowArgumentError, + runtime_error_extractor, +) from ainode.core.log import Logger from ainode.core.manager.model_manager import ModelManager from ainode.core.util.serde import convert_to_binary from ainode.core.util.status import get_status -from ainode.thrift.ainode.ttypes import TInferenceReq, TInferenceResp, TForecastReq, TForecastResp +from ainode.thrift.ainode.ttypes import ( + TForecastReq, + TForecastResp, + TInferenceReq, + TInferenceResp, +) logger = Logger() @@ -46,20 +55,23 @@ def infer(self, full_data, **kwargs): class TimerXLStrategy(InferenceStrategy): def infer(self, full_data, predict_length=96, **_): data = full_data[1][0] - if data.dtype.byteorder not in ('=', '|'): + if data.dtype.byteorder not in ("=", "|"): data = data.byteswap().newbyteorder() output = self.model.inference(data, int(predict_length)) df = pd.DataFrame(output[0]) return convert_to_binary(df) + class SundialStrategy(InferenceStrategy): def infer(self, full_data, predict_length=96, **_): data = full_data[1][0] - if data.dtype.byteorder not in ('=', '|'): + if data.dtype.byteorder not in ("=", "|"): data = data.byteswap().newbyteorder() seqs = torch.tensor(data).unsqueeze(0).float() # TODO: unify model inference input - output = self.model.generate(seqs, max_new_tokens=predict_length, num_samples=10, revin=True) + output = self.model.generate( + seqs, max_new_tokens=predict_length, num_samples=10, revin=True + ) df = pd.DataFrame(output[0].mean(dim=0)) return convert_to_binary(df) @@ -77,7 +89,7 @@ def infer(self, full_data, window_interval=None, window_step=None, **kwargs): _, dataset, _, length = full_data if window_interval is None or window_step is None: window_interval = length - window_step = float('inf') + window_step = float("inf") if window_interval <= 0 or window_step <= 0 or window_interval > length: raise InvalidWindowArgumentError(window_interval, window_step, length) @@ -88,7 +100,7 @@ def infer(self, full_data, window_interval=None, window_step=None, **kwargs): results = [] try: for i in range(times): - start = 0 if window_step == float('inf') else i * window_step + start = 0 if window_step == float("inf") else i * window_step end = start + window_interval window = data[:, start:end, :] out = self.model(window) @@ -103,11 +115,11 @@ def infer(self, full_data, window_interval=None, window_step=None, **kwargs): def _get_strategy(model_id, model): - if model_id == '_timerxl': + if model_id == "_timerxl": return TimerXLStrategy(model) - if model_id == '_sundial': + if model_id == "_sundial": return SundialStrategy(model) - if model_id.startswith('_'): + if model_id.startswith("_"): return BuiltInStrategy(model) return RegisteredStrategy(model) @@ -117,7 +129,15 @@ class InferenceManager: def __init__(self, model_manager: ModelManager): self.model_manager = model_manager - def _run(self, req, data_getter, deserializer, extract_attrs, resp_cls, single_output: bool): + def 
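# TimerXLStrategy and SundialStrategy above normalize the array's byte order
# before building tensors, since torch rejects non-native (e.g. big-endian)
# NumPy buffers. The diff calls data.byteswap().newbyteorder(), which exists on
# NumPy < 2; the view-based form below behaves the same and should also work on
# NumPy >= 2 (an assumption worth verifying against the pinned version):
import numpy as np

def to_native_byteorder(data):
    if data.dtype.byteorder not in ("=", "|"):  # "=" native, "|" byte-order-free
        data = data.byteswap().view(data.dtype.newbyteorder())
    return data

big_endian = np.arange(4, dtype=">f8")
assert np.allclose(to_native_byteorder(big_endian), np.arange(4))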
_run( + self, + req, + data_getter, + deserializer, + extract_attrs, + resp_cls, + single_output: bool, + ): model_id = req.modelId logger.info(f"Start processing for {model_id}") try: @@ -126,10 +146,10 @@ def _run(self, req, data_getter, deserializer, extract_attrs, resp_cls, single_o attrs = extract_attrs(req) # load model - if model_id.startswith('_'): + if model_id.startswith("_"): model = self.model_manager.load_built_in_model(model_id, attrs) else: - accel = str(attrs.get('acceleration', '')).lower() == 'true' + accel = str(attrs.get("acceleration", "")).lower() == "true" model = self.model_manager.load_model(model_id, accel) # inference by strategy @@ -146,7 +166,7 @@ def _run(self, req, data_getter, deserializer, extract_attrs, resp_cls, single_o except Exception as e: logger.error(e) status = get_status(TSStatusCode.AINODE_INTERNAL_ERROR, str(e)) - empty = b'' if single_output else [] + empty = b"" if single_output else [] return resp_cls(status, empty) def forecast(self, req: TForecastReq): @@ -154,9 +174,12 @@ def forecast(self, req: TForecastReq): req, data_getter=lambda r: r.inputData, deserializer=deserialize, - extract_attrs=lambda r: {'predict_length': r.outputLength, **(r.options or {})}, + extract_attrs=lambda r: { + "predict_length": r.outputLength, + **(r.options or {}), + }, resp_cls=TForecastResp, - single_output=True + single_output=True, ) def inference(self, req: TInferenceReq): @@ -165,10 +188,10 @@ def inference(self, req: TInferenceReq): data_getter=lambda r: r.dataset, deserializer=deserialize, extract_attrs=lambda r: { - 'window_interval': getattr(r.windowParams, 'windowInterval', None), - 'window_step': getattr(r.windowParams, 'windowStep', None), - **(r.inferenceAttributes or {}) + "window_interval": getattr(r.windowParams, "windowInterval", None), + "window_step": getattr(r.windowParams, "windowStep", None), + **(r.inferenceAttributes or {}), }, resp_cls=TInferenceResp, - single_output=False + single_output=False, ) diff --git a/iotdb-core/ainode/ainode/core/manager/model_manager.py b/iotdb-core/ainode/ainode/core/manager/model_manager.py index ead833a5839a3..95fdda1456b14 100644 --- a/iotdb-core/ainode/ainode/core/manager/model_manager.py +++ b/iotdb-core/ainode/ainode/core/manager/model_manager.py @@ -19,13 +19,21 @@ from yaml import YAMLError -from ainode.core.constant import TSStatusCode, BuiltInModelType -from ainode.core.exception import InvalidUriError, BadConfigValueError, BuiltInModelNotSupportError +from ainode.core.constant import BuiltInModelType, TSStatusCode +from ainode.core.exception import ( + BadConfigValueError, + BuiltInModelNotSupportError, + InvalidUriError, +) from ainode.core.log import Logger from ainode.core.model.built_in_model_factory import fetch_built_in_model from ainode.core.model.model_storage import ModelStorage from ainode.core.util.status import get_status -from ainode.thrift.ainode.ttypes import TRegisterModelReq, TRegisterModelResp, TDeleteModelReq +from ainode.thrift.ainode.ttypes import ( + TDeleteModelReq, + TRegisterModelReq, + TRegisterModelResp, +) from ainode.thrift.common.ttypes import TSStatus logger = Logger() @@ -38,26 +46,42 @@ def __init__(self): def register_model(self, req: TRegisterModelReq) -> TRegisterModelResp: logger.info(f"register model {req.modelId} from {req.uri}") try: - configs, attributes = self.model_storage.register_model(req.modelId, req.uri) - return TRegisterModelResp(get_status(TSStatusCode.SUCCESS_STATUS), configs, attributes) + configs, attributes = self.model_storage.register_model( 
+ req.modelId, req.uri + ) + return TRegisterModelResp( + get_status(TSStatusCode.SUCCESS_STATUS), configs, attributes + ) except InvalidUriError as e: logger.warning(e) self.model_storage.delete_model(req.modelId) - return TRegisterModelResp(get_status(TSStatusCode.INVALID_URI_ERROR, e.message)) + return TRegisterModelResp( + get_status(TSStatusCode.INVALID_URI_ERROR, e.message) + ) except BadConfigValueError as e: logger.warning(e) self.model_storage.delete_model(req.modelId) - return TRegisterModelResp(get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, e.message)) + return TRegisterModelResp( + get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, e.message) + ) except YAMLError as e: logger.warning(e) self.model_storage.delete_model(req.modelId) - if hasattr(e, 'problem_mark'): + if hasattr(e, "problem_mark"): mark = e.problem_mark - return TRegisterModelResp(get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, - f"An error occurred while parsing the yaml file, " - f"at line {mark.line + 1} column {mark.column + 1}.")) + return TRegisterModelResp( + get_status( + TSStatusCode.INVALID_INFERENCE_CONFIG, + f"An error occurred while parsing the yaml file, " + f"at line {mark.line + 1} column {mark.column + 1}.", + ) + ) return TRegisterModelResp( - get_status(TSStatusCode.INVALID_INFERENCE_CONFIG, f"An error occurred while parsing the yaml file")) + get_status( + TSStatusCode.INVALID_INFERENCE_CONFIG, + f"An error occurred while parsing the yaml file", + ) + ) except Exception as e: logger.warning(e) self.model_storage.delete_model(req.modelId) diff --git a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py index f699934e0c5e8..6298fb6a1db36 100644 --- a/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py +++ b/iotdb-core/ainode/ainode/core/model/built_in_model_factory.py @@ -17,27 +17,32 @@ # import os from abc import abstractmethod -from typing import List, Dict +from typing import Dict, List import numpy as np from sklearn.preprocessing import MinMaxScaler -from sktime.annotation.hmm_learn import GaussianHMM, GMMHMM +from sktime.annotation.hmm_learn import GMMHMM, GaussianHMM from sktime.annotation.stray import STRAY from sktime.forecasting.arima import ARIMA from sktime.forecasting.exp_smoothing import ExponentialSmoothing from sktime.forecasting.naive import NaiveForecaster from sktime.forecasting.trend import STLForecaster -from ainode.TimerXL.models import timer_xl -from ainode.TimerXL.models.configuration_timer import TimerxlConfig -from ainode.core.model.sundial import modeling_sundial from ainode.core.config import AINodeDescriptor from ainode.core.constant import AttributeName, BuiltInModelType -from ainode.core.exception import InferenceModelInternalError -from ainode.core.exception import WrongAttributeTypeError, NumericalRangeException, StringRangeException, \ - ListRangeException, BuiltInModelNotSupportError +from ainode.core.exception import ( + BuiltInModelNotSupportError, + InferenceModelInternalError, + ListRangeException, + NumericalRangeException, + StringRangeException, + WrongAttributeTypeError, +) from ainode.core.log import Logger +from ainode.core.model.sundial import modeling_sundial from ainode.core.model.sundial.configuration_sundial import SundialConfig +from ainode.TimerXL.models import timer_xl +from ainode.TimerXL.models.configuration_timer import TimerxlConfig logger = Logger() @@ -47,7 +52,10 @@ def get_model_attributes(model_id: str): attribute_map = arima_attribute_map elif 
model_id == BuiltInModelType.NAIVE_FORECASTER.value: attribute_map = naive_forecaster_attribute_map - elif model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value or model_id == BuiltInModelType.HOLTWINTERS.value: + elif ( + model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value + or model_id == BuiltInModelType.HOLTWINTERS.value + ): attribute_map = exponential_smoothing_attribute_map elif model_id == BuiltInModelType.STL_FORECASTER.value: attribute_map = stl_forecaster_attribute_map @@ -89,7 +97,10 @@ def fetch_built_in_model(model_id, inference_attributes): # build the built-in model if model_id == BuiltInModelType.ARIMA.value: model = ArimaModel(attributes) - elif model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value or model_id == BuiltInModelType.HOLTWINTERS.value: + elif ( + model_id == BuiltInModelType.EXPONENTIAL_SMOOTHING.value + or model_id == BuiltInModelType.HOLTWINTERS.value + ): model = ExponentialSmoothingModel(attributes) elif model_id == BuiltInModelType.NAIVE_FORECASTER.value: model = NaiveForecasterModel(attributes) @@ -104,7 +115,9 @@ def fetch_built_in_model(model_id, inference_attributes): elif model_id == BuiltInModelType.TIMER_XL.value: model = timer_xl.Model(TimerxlConfig.from_dict(attributes)) elif model_id == BuiltInModelType.SUNDIAL.value: - model = modeling_sundial.SundialForPrediction(SundialConfig.from_dict(attributes)) + model = modeling_sundial.SundialForPrediction( + SundialConfig.from_dict(attributes) + ) else: raise BuiltInModelNotSupportError(model_id) @@ -133,11 +146,13 @@ def parse(self, string_value: str): class IntAttribute(Attribute): - def __init__(self, name: str, - default_value: int, - default_low: int, - default_high: int, - ): + def __init__( + self, + name: str, + default_value: int, + default_low: int, + default_high: int, + ): super(IntAttribute, self).__init__(name) self.__default_value = default_value self.__default_low = default_low @@ -149,7 +164,9 @@ def get_default_value(self): def validate_value(self, value): if self.__default_low <= value <= self.__default_high: return True - raise NumericalRangeException(self._name, value, self.__default_low, self.__default_high) + raise NumericalRangeException( + self._name, value, self.__default_low, self.__default_high + ) def parse(self, string_value: str): try: @@ -160,11 +177,13 @@ def parse(self, string_value: str): class FloatAttribute(Attribute): - def __init__(self, name: str, - default_value: float, - default_low: float, - default_high: float, - ): + def __init__( + self, + name: str, + default_value: float, + default_low: float, + default_high: float, + ): super(FloatAttribute, self).__init__(name) self.__default_value = default_value self.__default_low = default_low @@ -176,7 +195,9 @@ def get_default_value(self): def validate_value(self, value): if self.__default_low <= value <= self.__default_high: return True - raise NumericalRangeException(self._name, value, self.__default_low, self.__default_high) + raise NumericalRangeException( + self._name, value, self.__default_low, self.__default_high + ) def parse(self, string_value: str): try: @@ -258,7 +279,9 @@ def parse(self, string_value: str): try: list_value[i] = self.__value_type(list_value[i]) except: - raise ListRangeException(self._name, list_value, self.__type_to_str[self.__value_type]) + raise ListRangeException( + self._name, list_value, self.__type_to_str[self.__value_type] + ) return list_value @@ -295,12 +318,16 @@ def parse(self, string_value: str): try: list_value[i] = self.__value_type(list_value[i]) except: - 
raise ListRangeException(self._name, list_value, self.__type_to_str[self.__value_type]) + raise ListRangeException( + self._name, list_value, self.__type_to_str[self.__value_type] + ) tuple_value = tuple(list_value) return tuple_value -def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, Attribute]): +def parse_attribute( + input_attributes: Dict[str, str], attribute_map: Dict[str, Attribute] +): """ Args: input_attributes: a dict of attributes, where the key is the attribute name, the value is the string value of @@ -321,47 +348,48 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A # user did not specify the attribute, use the default value else: try: - attributes[attribute_name] = attribute_map[attribute_name].get_default_value() + attributes[attribute_name] = attribute_map[ + attribute_name + ].get_default_value() except NotImplementedError as e: logger.error(f"attribute {attribute_name} is not implemented.") raise e return attributes + sundial_attribute_map = { AttributeName.INPUT_TOKEN_LEN.value: IntAttribute( name=AttributeName.INPUT_TOKEN_LEN.value, default_value=16, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.HIDDEN_SIZE.value: IntAttribute( name=AttributeName.HIDDEN_SIZE.value, default_value=768, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.INTERMEDIATE_SIZE.value: IntAttribute( name=AttributeName.INTERMEDIATE_SIZE.value, default_value=3072, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute( - name=AttributeName.OUTPUT_TOKEN_LENS.value, - default_value=[720], - value_type=int + name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[720], value_type=int ), AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute( name=AttributeName.NUM_HIDDEN_LAYERS.value, default_value=12, default_low=1, - default_high=16 + default_high=16, ), AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute( name=AttributeName.NUM_ATTENTION_HEADS.value, default_value=12, default_low=1, - default_high=192 + default_high=192, ), AttributeName.HIDDEN_ACT.value: StringAttribute( name=AttributeName.HIDDEN_ACT.value, @@ -376,50 +404,54 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A name=AttributeName.ROPE_THETA.value, default_value=10000, default_low=1000, - default_high=50000 + default_high=50000, ), AttributeName.DROPOUT_RATE.value: FloatAttribute( name=AttributeName.DROPOUT_RATE.value, default_value=0.1, default_low=0.0, - default_high=1.0 + default_high=1.0, ), AttributeName.INITIALIZER_RANGE.value: FloatAttribute( name=AttributeName.INITIALIZER_RANGE.value, default_value=0.02, default_low=0.0, - default_high=1.0 + default_high=1.0, ), AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute( name=AttributeName.MAX_POSITION_EMBEDDINGS.value, default_value=10000, default_low=1, - default_high=50000 + default_high=50000, ), AttributeName.FLOW_LOSS_DEPTH.value: IntAttribute( name=AttributeName.FLOW_LOSS_DEPTH.value, default_value=3, default_low=1, - default_high=50 + default_high=50, ), AttributeName.NUM_SAMPLING_STEPS.value: IntAttribute( name=AttributeName.NUM_SAMPLING_STEPS.value, default_value=50, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.DIFFUSION_BATCH_MUL.value: IntAttribute( name=AttributeName.DIFFUSION_BATCH_MUL.value, default_value=4, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.CKPT_PATH.value: StringAttribute( 
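# IntAttribute and FloatAttribute above wrap the same contract: parse the
# user-supplied string, range-check it, and fall back to a default when the
# attribute is absent. A compact re-implementation of that contract, with plain
# ValueError standing in for the custom exception classes:
def parse_int_attribute(name, raw, default, low, high):
    if raw is None:
        value = default
    else:
        try:
            value = int(raw)
        except ValueError:
            raise ValueError(f"Wrong type for attribute: {name}, expected: int")
    if not (low <= value <= high):
        raise ValueError(
            f"Attribute {name} expect value between {low} and {high}, got {value} instead."
        )
    return value

assert parse_int_attribute("predict_length", "96", 1, 1, 5000) == 96
assert parse_int_attribute("predict_length", None, 1, 1, 5000) == 1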
name=AttributeName.CKPT_PATH.value, - default_value=os.path.join(os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir(), 'weights', - 'sundial'), - value_choices=[''] - ) + default_value=os.path.join( + os.getcwd(), + AINodeDescriptor().get_config().get_ain_models_dir(), + "weights", + "sundial", + ), + value_choices=[""], + ), } timerxl_attribute_map = { @@ -427,36 +459,34 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A name=AttributeName.INPUT_TOKEN_LEN.value, default_value=96, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.HIDDEN_SIZE.value: IntAttribute( name=AttributeName.HIDDEN_SIZE.value, default_value=1024, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.INTERMEDIATE_SIZE.value: IntAttribute( name=AttributeName.INTERMEDIATE_SIZE.value, default_value=2048, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.OUTPUT_TOKEN_LENS.value: ListAttribute( - name=AttributeName.OUTPUT_TOKEN_LENS.value, - default_value=[96], - value_type=int + name=AttributeName.OUTPUT_TOKEN_LENS.value, default_value=[96], value_type=int ), AttributeName.NUM_HIDDEN_LAYERS.value: IntAttribute( name=AttributeName.NUM_HIDDEN_LAYERS.value, default_value=8, default_low=1, - default_high=16 + default_high=16, ), AttributeName.NUM_ATTENTION_HEADS.value: IntAttribute( name=AttributeName.NUM_ATTENTION_HEADS.value, default_value=8, default_low=1, - default_high=192 + default_high=192, ), AttributeName.HIDDEN_ACT.value: StringAttribute( name=AttributeName.HIDDEN_ACT.value, @@ -471,32 +501,37 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A name=AttributeName.ROPE_THETA.value, default_value=10000, default_low=1000, - default_high=50000 + default_high=50000, ), AttributeName.ATTENTION_DROPOUT.value: FloatAttribute( name=AttributeName.ATTENTION_DROPOUT.value, default_value=0.0, default_low=0.0, - default_high=1.0 + default_high=1.0, ), AttributeName.INITIALIZER_RANGE.value: FloatAttribute( name=AttributeName.INITIALIZER_RANGE.value, default_value=0.02, default_low=0.0, - default_high=1.0 + default_high=1.0, ), AttributeName.MAX_POSITION_EMBEDDINGS.value: IntAttribute( name=AttributeName.MAX_POSITION_EMBEDDINGS.value, default_value=10000, default_low=1, - default_high=50000 + default_high=50000, ), AttributeName.CKPT_PATH.value: StringAttribute( name=AttributeName.CKPT_PATH.value, - default_value=os.path.join(os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir(), 'weights', - 'timerxl', 'model.safetensors'), - value_choices=[''] - ) + default_value=os.path.join( + os.getcwd(), + AINodeDescriptor().get_config().get_ain_models_dir(), + "weights", + "timerxl", + "model.safetensors", + ), + value_choices=[""], + ), } # built-in sktime model attributes @@ -506,7 +541,7 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A name=AttributeName.PREDICT_LENGTH.value, default_value=1, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.STRATEGY.value: StringAttribute( name=AttributeName.STRATEGY.value, @@ -514,10 +549,7 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A value_choices=["last", "mean"], ), AttributeName.SP.value: IntAttribute( - name=AttributeName.SP.value, - default_value=1, - default_low=1, - default_high=5000 + name=AttributeName.SP.value, default_value=1, default_low=1, default_high=5000 ), } # ExponentialSmoothing @@ -526,7 +558,7 @@ def parse_attribute(input_attributes: 
Dict[str, str], attribute_map: Dict[str, A name=AttributeName.PREDICT_LENGTH.value, default_value=1, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.DAMPED_TREND.value: BooleanAttribute( name=AttributeName.DAMPED_TREND.value, @@ -548,7 +580,7 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A AttributeName.USE_BRUTE.value: BooleanAttribute( name=AttributeName.USE_BRUTE.value, default_value=False, - ) + ), } # Arima arima_attribute_map = { @@ -556,17 +588,15 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A name=AttributeName.PREDICT_LENGTH.value, default_value=1, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.ORDER.value: TupleAttribute( - name=AttributeName.ORDER.value, - default_value=(1, 0, 0), - value_type=int + name=AttributeName.ORDER.value, default_value=(1, 0, 0), value_type=int ), AttributeName.SEASONAL_ORDER.value: TupleAttribute( name=AttributeName.SEASONAL_ORDER.value, default_value=(0, 0, 0, 0), - value_type=int + value_type=int, ), AttributeName.METHOD.value: StringAttribute( name=AttributeName.METHOD.value, @@ -577,7 +607,7 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A name=AttributeName.MAXITER.value, default_value=1, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.SUPPRESS_WARNINGS.value: BooleanAttribute( name=AttributeName.SUPPRESS_WARNINGS.value, @@ -587,7 +617,7 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A name=AttributeName.OUT_OF_SAMPLE_SIZE.value, default_value=0, default_low=0, - default_high=5000 + default_high=5000, ), AttributeName.SCORING.value: StringAttribute( name=AttributeName.SCORING.value, @@ -629,7 +659,7 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A AttributeName.CONCENTRATE_SCALE.value: BooleanAttribute( name=AttributeName.CONCENTRATE_SCALE.value, default_value=False, - ) + ), } # STLForecaster stl_forecaster_attribute_map = { @@ -637,55 +667,52 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A name=AttributeName.PREDICT_LENGTH.value, default_value=1, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.SP.value: IntAttribute( - name=AttributeName.SP.value, - default_value=2, - default_low=1, - default_high=5000 + name=AttributeName.SP.value, default_value=2, default_low=1, default_high=5000 ), AttributeName.SEASONAL.value: IntAttribute( name=AttributeName.SEASONAL.value, default_value=7, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.SEASONAL_DEG.value: IntAttribute( name=AttributeName.SEASONAL_DEG.value, default_value=1, default_low=0, - default_high=5000 + default_high=5000, ), AttributeName.TREND_DEG.value: IntAttribute( name=AttributeName.TREND_DEG.value, default_value=1, default_low=0, - default_high=5000 + default_high=5000, ), AttributeName.LOW_PASS_DEG.value: IntAttribute( name=AttributeName.LOW_PASS_DEG.value, default_value=1, default_low=0, - default_high=5000 + default_high=5000, ), AttributeName.SEASONAL_JUMP.value: IntAttribute( name=AttributeName.SEASONAL_JUMP.value, default_value=1, default_low=0, - default_high=5000 + default_high=5000, ), AttributeName.TREND_JUMP.value: IntAttribute( name=AttributeName.TREND_JUMP.value, default_value=1, default_low=0, - default_high=5000 + default_high=5000, ), AttributeName.LOSS_PASS_JUMP.value: IntAttribute( name=AttributeName.LOSS_PASS_JUMP.value, default_value=1, 
default_low=0, - default_high=5000 + default_high=5000, ), } @@ -695,7 +722,7 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A name=AttributeName.N_COMPONENTS.value, default_value=1, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.COVARIANCE_TYPE.value: StringAttribute( name=AttributeName.COVARIANCE_TYPE.value, @@ -753,7 +780,7 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A name=AttributeName.N_ITER.value, default_value=10, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.TOL.value: FloatAttribute( name=AttributeName.TOL.value, @@ -775,7 +802,7 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A name=AttributeName.IMPLEMENTATION.value, default_value="log", value_choices=["log", "scaling"], - ) + ), } # GMMHMM @@ -784,13 +811,13 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A name=AttributeName.N_COMPONENTS.value, default_value=1, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.N_MIX.value: IntAttribute( name=AttributeName.N_MIX.value, default_value=1, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.MIN_COVAR.value: FloatAttribute( name=AttributeName.MIN_COVAR.value, @@ -842,7 +869,7 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A name=AttributeName.N_ITER.value, default_value=10, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.TOL.value: FloatAttribute( name=AttributeName.TOL.value, @@ -853,22 +880,82 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A AttributeName.INIT_PARAMS.value: StringAttribute( name=AttributeName.INIT_PARAMS.value, default_value="stmcw", - value_choices=["s", "t", "m", "c", "w", "st", "sm", "sc", "sw", "tm", "tc", "tw", "mc", "mw", "cw", "stm", - "stc", "stw", "smc", "smw", "scw", "tmc", "tmw", "tcw", "mcw", "stmc", "stmw", "stcw", "smcw", - "tmcw", "stmcw"] + value_choices=[ + "s", + "t", + "m", + "c", + "w", + "st", + "sm", + "sc", + "sw", + "tm", + "tc", + "tw", + "mc", + "mw", + "cw", + "stm", + "stc", + "stw", + "smc", + "smw", + "scw", + "tmc", + "tmw", + "tcw", + "mcw", + "stmc", + "stmw", + "stcw", + "smcw", + "tmcw", + "stmcw", + ], ), AttributeName.PARAMS.value: StringAttribute( name=AttributeName.PARAMS.value, default_value="stmcw", - value_choices=["s", "t", "m", "c", "w", "st", "sm", "sc", "sw", "tm", "tc", "tw", "mc", "mw", "cw", "stm", - "stc", "stw", "smc", "smw", "scw", "tmc", "tmw", "tcw", "mcw", "stmc", "stmw", "stcw", "smcw", - "tmcw", "stmcw"] + value_choices=[ + "s", + "t", + "m", + "c", + "w", + "st", + "sm", + "sc", + "sw", + "tm", + "tc", + "tw", + "mc", + "mw", + "cw", + "stm", + "stc", + "stw", + "smc", + "smw", + "scw", + "tmc", + "tmw", + "tcw", + "mcw", + "stmc", + "stmw", + "stcw", + "smcw", + "tmcw", + "stmcw", + ], ), AttributeName.IMPLEMENTATION.value: StringAttribute( name=AttributeName.IMPLEMENTATION.value, default_value="log", value_choices=["log", "scaling"], - ) + ), } # STRAY @@ -880,10 +967,7 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A default_high=1e10, ), AttributeName.K.value: IntAttribute( - name=AttributeName.K.value, - default_value=10, - default_low=1, - default_high=5000 + name=AttributeName.K.value, default_value=10, default_low=1, default_high=5000 ), AttributeName.KNN_ALGORITHM.value: StringAttribute( name=AttributeName.KNN_ALGORITHM.value, @@ -900,7 
+984,7 @@ def parse_attribute(input_attributes: Dict[str, str], attribute_map: Dict[str, A name=AttributeName.SIZE_THRESHOLD.value, default_value=50, default_low=1, - default_high=5000 + default_high=5000, ), AttributeName.OUTLIER_TAIL.value: StringAttribute( name=AttributeName.OUTLIER_TAIL.value, @@ -924,27 +1008,27 @@ class ArimaModel(BuiltInModel): def __init__(self, attributes): super(ArimaModel, self).__init__(attributes) self._model = ARIMA( - order=attributes['order'], - seasonal_order=attributes['seasonal_order'], - method=attributes['method'], - suppress_warnings=attributes['suppress_warnings'], - maxiter=attributes['maxiter'], - out_of_sample_size=attributes['out_of_sample_size'], - scoring=attributes['scoring'], - with_intercept=attributes['with_intercept'], - time_varying_regression=attributes['time_varying_regression'], - enforce_stationarity=attributes['enforce_stationarity'], - enforce_invertibility=attributes['enforce_invertibility'], - simple_differencing=attributes['simple_differencing'], - measurement_error=attributes['measurement_error'], - mle_regression=attributes['mle_regression'], - hamilton_representation=attributes['hamilton_representation'], - concentrate_scale=attributes['concentrate_scale'] + order=attributes["order"], + seasonal_order=attributes["seasonal_order"], + method=attributes["method"], + suppress_warnings=attributes["suppress_warnings"], + maxiter=attributes["maxiter"], + out_of_sample_size=attributes["out_of_sample_size"], + scoring=attributes["scoring"], + with_intercept=attributes["with_intercept"], + time_varying_regression=attributes["time_varying_regression"], + enforce_stationarity=attributes["enforce_stationarity"], + enforce_invertibility=attributes["enforce_invertibility"], + simple_differencing=attributes["simple_differencing"], + measurement_error=attributes["measurement_error"], + mle_regression=attributes["mle_regression"], + hamilton_representation=attributes["hamilton_representation"], + concentrate_scale=attributes["concentrate_scale"], ) def inference(self, data): try: - predict_length = self._attributes['predict_length'] + predict_length = self._attributes["predict_length"] self._model.fit(data) output = self._model.predict(fh=range(predict_length)) output = np.array(output, dtype=np.float64) @@ -957,16 +1041,16 @@ class ExponentialSmoothingModel(BuiltInModel): def __init__(self, attributes): super(ExponentialSmoothingModel, self).__init__(attributes) self._model = ExponentialSmoothing( - damped_trend=attributes['damped_trend'], - initialization_method=attributes['initialization_method'], - optimized=attributes['optimized'], - remove_bias=attributes['remove_bias'], - use_brute=attributes['use_brute'] + damped_trend=attributes["damped_trend"], + initialization_method=attributes["initialization_method"], + optimized=attributes["optimized"], + remove_bias=attributes["remove_bias"], + use_brute=attributes["use_brute"], ) def inference(self, data): try: - predict_length = self._attributes['predict_length'] + predict_length = self._attributes["predict_length"] self._model.fit(data) output = self._model.predict(fh=range(predict_length)) output = np.array(output, dtype=np.float64) @@ -979,13 +1063,12 @@ class NaiveForecasterModel(BuiltInModel): def __init__(self, attributes): super(NaiveForecasterModel, self).__init__(attributes) self._model = NaiveForecaster( - strategy=attributes['strategy'], - sp=attributes['sp'] + strategy=attributes["strategy"], sp=attributes["sp"] ) def inference(self, data): try: - predict_length = 
self._attributes['predict_length'] + predict_length = self._attributes["predict_length"] self._model.fit(data) output = self._model.predict(fh=range(predict_length)) output = np.array(output, dtype=np.float64) @@ -998,19 +1081,19 @@ class STLForecasterModel(BuiltInModel): def __init__(self, attributes): super(STLForecasterModel, self).__init__(attributes) self._model = STLForecaster( - sp=attributes['sp'], - seasonal=attributes['seasonal'], - seasonal_deg=attributes['seasonal_deg'], - trend_deg=attributes['trend_deg'], - low_pass_deg=attributes['low_pass_deg'], - seasonal_jump=attributes['seasonal_jump'], - trend_jump=attributes['trend_jump'], - low_pass_jump=attributes['low_pass_jump'] + sp=attributes["sp"], + seasonal=attributes["seasonal"], + seasonal_deg=attributes["seasonal_deg"], + trend_deg=attributes["trend_deg"], + low_pass_deg=attributes["low_pass_deg"], + seasonal_jump=attributes["seasonal_jump"], + trend_jump=attributes["trend_jump"], + low_pass_jump=attributes["low_pass_jump"], ) def inference(self, data): try: - predict_length = self._attributes['predict_length'] + predict_length = self._attributes["predict_length"] self._model.fit(data) output = self._model.predict(fh=range(predict_length)) output = np.array(output, dtype=np.float64) @@ -1023,21 +1106,21 @@ class GMMHMMModel(BuiltInModel): def __init__(self, attributes): super(GMMHMMModel, self).__init__(attributes) self._model = GMMHMM( - n_components=attributes['n_components'], - n_mix=attributes['n_mix'], - min_covar=attributes['min_covar'], - startprob_prior=attributes['startprob_prior'], - transmat_prior=attributes['transmat_prior'], - means_prior=attributes['means_prior'], - means_weight=attributes['means_weight'], - weights_prior=attributes['weights_prior'], - algorithm=attributes['algorithm'], - covariance_type=attributes['covariance_type'], - n_iter=attributes['n_iter'], - tol=attributes['tol'], - params=attributes['params'], - init_params=attributes['init_params'], - implementation=attributes['implementation'] + n_components=attributes["n_components"], + n_mix=attributes["n_mix"], + min_covar=attributes["min_covar"], + startprob_prior=attributes["startprob_prior"], + transmat_prior=attributes["transmat_prior"], + means_prior=attributes["means_prior"], + means_weight=attributes["means_weight"], + weights_prior=attributes["weights_prior"], + algorithm=attributes["algorithm"], + covariance_type=attributes["covariance_type"], + n_iter=attributes["n_iter"], + tol=attributes["tol"], + params=attributes["params"], + init_params=attributes["init_params"], + implementation=attributes["implementation"], ) def inference(self, data): @@ -1054,21 +1137,21 @@ class GaussianHmmModel(BuiltInModel): def __init__(self, attributes): super(GaussianHmmModel, self).__init__(attributes) self._model = GaussianHMM( - n_components=attributes['n_components'], - covariance_type=attributes['covariance_type'], - min_covar=attributes['min_covar'], - startprob_prior=attributes['startprob_prior'], - transmat_prior=attributes['transmat_prior'], - means_prior=attributes['means_prior'], - means_weight=attributes['means_weight'], - covars_prior=attributes['covars_prior'], - covars_weight=attributes['covars_weight'], - algorithm=attributes['algorithm'], - n_iter=attributes['n_iter'], - tol=attributes['tol'], - params=attributes['params'], - init_params=attributes['init_params'], - implementation=attributes['implementation'] + n_components=attributes["n_components"], + covariance_type=attributes["covariance_type"], + min_covar=attributes["min_covar"], + 
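# Every sktime wrapper above implements the same inference contract: fit on the
# input series, predict `predict_length` points, and return them as float64.
# A minimal end-to-end run with NaiveForecaster (a 1..h relative horizon is
# used here for clarity; the wrappers themselves pass fh=range(predict_length)):
import numpy as np
import pandas as pd
from sktime.forecasting.naive import NaiveForecaster

y = pd.Series(np.sin(np.arange(48) / 6.0))
model = NaiveForecaster(strategy="last", sp=1)
model.fit(y)
predict_length = 3
output = np.array(model.predict(fh=range(1, predict_length + 1)), dtype=np.float64)
assert output.shape == (predict_length,)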
startprob_prior=attributes["startprob_prior"], + transmat_prior=attributes["transmat_prior"], + means_prior=attributes["means_prior"], + means_weight=attributes["means_weight"], + covars_prior=attributes["covars_prior"], + covars_weight=attributes["covars_weight"], + algorithm=attributes["algorithm"], + n_iter=attributes["n_iter"], + tol=attributes["tol"], + params=attributes["params"], + init_params=attributes["init_params"], + implementation=attributes["implementation"], ) def inference(self, data): @@ -1085,12 +1168,12 @@ class STRAYModel(BuiltInModel): def __init__(self, attributes): super(STRAYModel, self).__init__(attributes) self._model = STRAY( - alpha=attributes['alpha'], - k=attributes['k'], - knn_algorithm=attributes['knn_algorithm'], - p=attributes['p'], - size_threshold=attributes['size_threshold'], - outlier_tail=attributes['outlier_tail'] + alpha=attributes["alpha"], + k=attributes["k"], + knn_algorithm=attributes["knn_algorithm"], + p=attributes["p"], + size_threshold=attributes["size_threshold"], + outlier_tail=attributes["outlier_tail"], ) def inference(self, data): diff --git a/iotdb-core/ainode/ainode/core/model/model_factory.py b/iotdb-core/ainode/ainode/core/model/model_factory.py index 1700dd28eb64a..702826f9cba56 100644 --- a/iotdb-core/ainode/ainode/core/model/model_factory.py +++ b/iotdb-core/ainode/ainode/core/model/model_factory.py @@ -17,15 +17,20 @@ # import os import shutil -from urllib.parse import urlparse, urljoin +from urllib.parse import urljoin, urlparse import yaml from requests import Session from requests.adapters import HTTPAdapter -from ainode.core.constant import DEFAULT_RECONNECT_TIMES, DEFAULT_RECONNECT_TIMEOUT, DEFAULT_CHUNK_SIZE, \ - DEFAULT_CONFIG_FILE_NAME, DEFAULT_MODEL_FILE_NAME -from ainode.core.exception import InvalidUriError, BadConfigValueError +from ainode.core.constant import ( + DEFAULT_CHUNK_SIZE, + DEFAULT_CONFIG_FILE_NAME, + DEFAULT_MODEL_FILE_NAME, + DEFAULT_RECONNECT_TIMEOUT, + DEFAULT_RECONNECT_TIMES, +) +from ainode.core.exception import BadConfigValueError, InvalidUriError from ainode.core.log import Logger from ainode.core.util.serde import get_data_type_byte_from_str from ainode.thrift.ainode.ttypes import TConfigs @@ -46,12 +51,12 @@ def _parse_uri(uri): """ parse_result = urlparse(uri) - is_network_path = parse_result.scheme in ('http', 'https') + is_network_path = parse_result.scheme in ("http", "https") if is_network_path: return True, uri # handle file:// in uri - if parse_result.scheme == 'file': + if parse_result.scheme == "file": uri = uri[7:] # handle ~ in uri @@ -77,7 +82,7 @@ def _download_file(url: str, storage_path: str) -> None: response = session.get(url, timeout=DEFAULT_RECONNECT_TIMEOUT, stream=True) response.raise_for_status() - with open(storage_path, 'wb') as file: + with open(storage_path, "wb") as file: for chunk in response.iter_content(chunk_size=DEFAULT_CHUNK_SIZE): if chunk: file.write(chunk) @@ -85,8 +90,9 @@ def _download_file(url: str, storage_path: str) -> None: logger.debug(f"download file from {url} to {storage_path} success") -def _register_model_from_network(uri: str, model_storage_path: str, - config_storage_path: str) -> [TConfigs, str]: +def _register_model_from_network( + uri: str, model_storage_path: str, config_storage_path: str +) -> [TConfigs, str]: """ Args: uri: network dir path of model to register, where model.pt and config.yaml are required, @@ -106,7 +112,7 @@ def _register_model_from_network(uri: str, model_storage_path: str, _download_file(target_config_path, 
config_storage_path) # read and parse config dict from config.yaml - with open(config_storage_path, 'r', encoding='utf-8') as file: + with open(config_storage_path, "r", encoding="utf-8") as file: config_dict = yaml.safe_load(file) configs, attributes = _parse_inference_config(config_dict) @@ -115,8 +121,9 @@ def _register_model_from_network(uri: str, model_storage_path: str, return configs, attributes -def _register_model_from_local(uri: str, model_storage_path: str, - config_storage_path: str) -> [TConfigs, str]: +def _register_model_from_local( + uri: str, model_storage_path: str, config_storage_path: str +) -> [TConfigs, str]: """ Args: uri: local dir path of model to register, where model.pt and config.yaml are required, @@ -141,17 +148,21 @@ def _register_model_from_local(uri: str, model_storage_path: str, # copy config.yaml logger.debug(f"copy file from {target_config_path} to {config_storage_path}") shutil.copy(target_config_path, config_storage_path) - logger.debug(f"copy file from {target_config_path} to {config_storage_path} success") + logger.debug( + f"copy file from {target_config_path} to {config_storage_path} success" + ) # read and parse config dict from config.yaml - with open(config_storage_path, 'r', encoding='utf-8') as file: + with open(config_storage_path, "r", encoding="utf-8") as file: config_dict = yaml.safe_load(file) configs, attributes = _parse_inference_config(config_dict) # if config.yaml is correct, copy model file logger.debug(f"copy file from {target_model_path} to {model_storage_path}") shutil.copy(target_model_path, model_storage_path) - logger.debug(f"copy file from {target_model_path} to {model_storage_path} success") + logger.debug( + f"copy file from {target_model_path} to {model_storage_path} success" + ) elif not exist_model_file or not exist_config_file: raise InvalidUriError(uri) @@ -173,63 +184,108 @@ def _parse_inference_config(config_dict): configs: TConfigs attributes: str """ - configs = config_dict['configs'] + configs = config_dict["configs"] # check if input_shape and output_shape are two-dimensional array - if not (isinstance(configs['input_shape'], list) and len(configs['input_shape']) == 2): - raise BadConfigValueError('input_shape', configs['input_shape'], - 'input_shape should be a two-dimensional array.') - if not (isinstance(configs['output_shape'], list) and len(configs['output_shape']) == 2): - raise BadConfigValueError('output_shape', configs['output_shape'], - 'output_shape should be a two-dimensional array.') + if not ( + isinstance(configs["input_shape"], list) and len(configs["input_shape"]) == 2 + ): + raise BadConfigValueError( + "input_shape", + configs["input_shape"], + "input_shape should be a two-dimensional array.", + ) + if not ( + isinstance(configs["output_shape"], list) and len(configs["output_shape"]) == 2 + ): + raise BadConfigValueError( + "output_shape", + configs["output_shape"], + "output_shape should be a two-dimensional array.", + ) # check if input_shape and output_shape are positive integer - input_shape_is_positive_number = isinstance(configs['input_shape'][0], int) and isinstance( - configs['input_shape'][1], int) and configs['input_shape'][0] > 0 and configs['input_shape'][1] > 0 + input_shape_is_positive_number = ( + isinstance(configs["input_shape"][0], int) + and isinstance(configs["input_shape"][1], int) + and configs["input_shape"][0] > 0 + and configs["input_shape"][1] > 0 + ) if not input_shape_is_positive_number: - raise BadConfigValueError('input_shape', configs['input_shape'], - 'element in 
input_shape should be positive integer.') - - output_shape_is_positive_number = isinstance(configs['output_shape'][0], int) and isinstance( - configs['output_shape'][1], int) and configs['output_shape'][0] > 0 and configs['output_shape'][1] > 0 + raise BadConfigValueError( + "input_shape", + configs["input_shape"], + "element in input_shape should be positive integer.", + ) + + output_shape_is_positive_number = ( + isinstance(configs["output_shape"][0], int) + and isinstance(configs["output_shape"][1], int) + and configs["output_shape"][0] > 0 + and configs["output_shape"][1] > 0 + ) if not output_shape_is_positive_number: - raise BadConfigValueError('output_shape', configs['output_shape'], - 'element in output_shape should be positive integer.') + raise BadConfigValueError( + "output_shape", + configs["output_shape"], + "element in output_shape should be positive integer.", + ) # check if input_type and output_type are one-dimensional array with right length - if 'input_type' in configs and not ( - isinstance(configs['input_type'], list) and len(configs['input_type']) == configs['input_shape'][1]): - raise BadConfigValueError('input_type', configs['input_type'], - 'input_type should be a one-dimensional array and length of it should be equal to input_shape[1].') - - if 'output_type' in configs and not ( - isinstance(configs['output_type'], list) and len(configs['output_type']) == configs['output_shape'][1]): - raise BadConfigValueError('output_type', configs['output_type'], - 'output_type should be a one-dimensional array and length of it should be equal to output_shape[1].') + if "input_type" in configs and not ( + isinstance(configs["input_type"], list) + and len(configs["input_type"]) == configs["input_shape"][1] + ): + raise BadConfigValueError( + "input_type", + configs["input_type"], + "input_type should be a one-dimensional array and length of it should be equal to input_shape[1].", + ) + + if "output_type" in configs and not ( + isinstance(configs["output_type"], list) + and len(configs["output_type"]) == configs["output_shape"][1] + ): + raise BadConfigValueError( + "output_type", + configs["output_type"], + "output_type should be a one-dimensional array and length of it should be equal to output_shape[1].", + ) # parse input_type and output_type to byte - if 'input_type' in configs: - input_type = [get_data_type_byte_from_str(x) for x in configs['input_type']] + if "input_type" in configs: + input_type = [get_data_type_byte_from_str(x) for x in configs["input_type"]] else: - input_type = [get_data_type_byte_from_str('float32')] * configs['input_shape'][1] + input_type = [get_data_type_byte_from_str("float32")] * configs["input_shape"][ + 1 + ] - if 'output_type' in configs: - output_type = [get_data_type_byte_from_str(x) for x in configs['output_type']] + if "output_type" in configs: + output_type = [get_data_type_byte_from_str(x) for x in configs["output_type"]] else: - output_type = [get_data_type_byte_from_str('float32')] * configs['output_shape'][1] + output_type = [get_data_type_byte_from_str("float32")] * configs[ + "output_shape" + ][1] # parse attributes attributes = "" - if 'attributes' in config_dict: - attributes = str(config_dict['attributes']) + if "attributes" in config_dict: + attributes = str(config_dict["attributes"]) - return TConfigs(configs['input_shape'], configs['output_shape'], input_type, output_type), attributes + return ( + TConfigs( + configs["input_shape"], configs["output_shape"], input_type, output_type + ), + attributes, + ) def 
fetch_model_by_uri(uri: str, model_storage_path: str, config_storage_path: str): is_network_path, uri = _parse_uri(uri) if is_network_path: - return _register_model_from_network(uri, model_storage_path, config_storage_path) + return _register_model_from_network( + uri, model_storage_path, config_storage_path + ) else: return _register_model_from_local(uri, model_storage_path, config_storage_path) diff --git a/iotdb-core/ainode/ainode/core/model/model_storage.py b/iotdb-core/ainode/ainode/core/model/model_storage.py index 43ebf6c06b796..c0e2a21c80a8d 100644 --- a/iotdb-core/ainode/ainode/core/model/model_storage.py +++ b/iotdb-core/ainode/ainode/core/model/model_storage.py @@ -25,18 +25,20 @@ from pylru import lrucache from ainode.core.config import AINodeDescriptor -from ainode.core.constant import (DEFAULT_MODEL_FILE_NAME, - DEFAULT_CONFIG_FILE_NAME) +from ainode.core.constant import DEFAULT_CONFIG_FILE_NAME, DEFAULT_MODEL_FILE_NAME from ainode.core.exception import ModelNotExistError from ainode.core.log import Logger from ainode.core.model.model_factory import fetch_model_by_uri from ainode.core.util.lock import ModelLockPool + logger = Logger() class ModelStorage(object): def __init__(self): - self._model_dir = os.path.join(os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir()) + self._model_dir = os.path.join( + os.getcwd(), AINodeDescriptor().get_config().get_ain_models_dir() + ) if not os.path.exists(self._model_dir): try: os.makedirs(self._model_dir) @@ -44,7 +46,9 @@ def __init__(self): logger.error(e) raise e self._lock_pool = ModelLockPool() - self._model_cache = lrucache(AINodeDescriptor().get_config().get_ain_model_storage_cache_size()) + self._model_cache = lrucache( + AINodeDescriptor().get_config().get_ain_model_storage_cache_size() + ) def register_model(self, model_id: str, uri: str): """ @@ -56,7 +60,7 @@ def register_model(self, model_id: str, uri: str): configs: TConfigs attributes: str """ - storage_path = os.path.join(self._model_dir, f'{model_id}') + storage_path = os.path.join(self._model_dir, f"{model_id}") # create storage dir if not exist if not os.path.exists(storage_path): os.makedirs(storage_path) @@ -69,12 +73,15 @@ def load_model(self, model_id: str, acceleration: bool) -> Callable: Returns: model: a ScriptModule contains model architecture and parameters, which can be deployed cross-platform """ - ain_models_dir = os.path.join(self._model_dir, f'{model_id}') + ain_models_dir = os.path.join(self._model_dir, f"{model_id}") model_path = os.path.join(ain_models_dir, DEFAULT_MODEL_FILE_NAME) with self._lock_pool.get_lock(model_id).read_lock(): if model_path in self._model_cache: model = self._model_cache[model_path] - if isinstance(model, torch._dynamo.eval_frame.OptimizedModule) or not acceleration: + if ( + isinstance(model, torch._dynamo.eval_frame.OptimizedModule) + or not acceleration + ): return model else: model = torch.compile(model) @@ -89,7 +96,9 @@ def load_model(self, model_id: str, acceleration: bool) -> Callable: try: model = torch.compile(model) except Exception as e: - logger.warning(f"acceleration failed, fallback to normal mode: {str(e)}") + logger.warning( + f"acceleration failed, fallback to normal mode: {str(e)}" + ) self._model_cache[model_path] = model return model @@ -100,7 +109,7 @@ def delete_model(self, model_id: str) -> None: Returns: None """ - storage_path = os.path.join(self._model_dir, f'{model_id}') + storage_path = os.path.join(self._model_dir, f"{model_id}") with self._lock_pool.get_lock(model_id).write_lock(): if 
os.path.exists(storage_path): for file_name in os.listdir(storage_path): diff --git a/iotdb-core/ainode/ainode/core/model/sundial/__init__.py b/iotdb-core/ainode/ainode/core/model/sundial/__init__.py index 4b8ee97fad2be..2a1e720805f29 100644 --- a/iotdb-core/ainode/ainode/core/model/sundial/__init__.py +++ b/iotdb-core/ainode/ainode/core/model/sundial/__init__.py @@ -14,4 +14,4 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# \ No newline at end of file +# diff --git a/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py b/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py index 41c54ff4a721a..c903ce3e9dd68 100644 --- a/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py +++ b/iotdb-core/ainode/ainode/core/model/sundial/configuration_sundial.py @@ -17,6 +17,7 @@ # from typing import List + from transformers import PretrainedConfig @@ -41,7 +42,7 @@ def __init__( flow_loss_depth: int = 3, num_sampling_steps: int = 50, diffusion_batch_mul: int = 4, - ckpt_path: str = None, # weight path + ckpt_path: str = None, # weight path **kwargs, ): self.input_token_len = input_token_len diff --git a/iotdb-core/ainode/ainode/core/model/sundial/flow_loss.py b/iotdb-core/ainode/ainode/core/model/sundial/flow_loss.py index 79fcd73c15502..b3fe95dbe2d27 100644 --- a/iotdb-core/ainode/ainode/core/model/sundial/flow_loss.py +++ b/iotdb-core/ainode/ainode/core/model/sundial/flow_loss.py @@ -16,9 +16,10 @@ # under the License. # +import math + import torch import torch.nn as nn -import math class FlowLoss(nn.Module): @@ -32,7 +33,7 @@ def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps model_channels=width, out_channels=target_channels, z_channels=z_channels, - num_res_blocks=depth + num_res_blocks=depth, ) self.num_sampling_steps = num_sampling_steps @@ -44,8 +45,9 @@ def forward(self, target, z, mask=None, mask_y=None): predict_v = self.net(noised_target, t * 1000, z) - weights = 1.0 / \ - torch.arange(1, self.in_channels + 1, dtype=torch.float32, device=target.device) + weights = 1.0 / torch.arange( + 1, self.in_channels + 1, dtype=torch.float32, device=target.device + ) if mask_y is not None: loss = (mask_y * weights * (predict_v - target) ** 2).sum(dim=-1) else: @@ -61,8 +63,7 @@ def sample(self, z, num_samples=1): x = noise dt = 1.0 / self.num_sampling_steps for i in range(self.num_sampling_steps): - t = (torch.ones((x.shape[0])) * i / - self.num_sampling_steps).to(x.device) + t = (torch.ones((x.shape[0])) * i / self.num_sampling_steps).to(x.device) pred = self.net(x, t * 1000, z) x = x + (pred - noise) * dt x = x.reshape(num_samples, -1, self.in_channels).transpose(0, 1) @@ -100,14 +101,16 @@ def timestep_embedding(t, dim, max_period=10000): # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py half = dim // 2 freqs = torch.exp( - -math.log(max_period) * torch.arange(start=0, - end=half, dtype=torch.float32) / half + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half ).to(device=t.device) args = t[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) if dim % 2: embedding = torch.cat( - [embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) return embedding def forward(self, t): @@ -122,10 +125,7 @@ class ResBlock(nn.Module): :param channels: the number of input channels. 
""" - def __init__( - self, - channels - ): + def __init__(self, channels): super().__init__() self.channels = channels @@ -137,13 +137,11 @@ def __init__( ) self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(channels, 3 * channels, bias=True) + nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True) ) def forward(self, x, y): - shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation( - y).chunk(3, dim=-1) + shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) h = modulate(self.in_ln(x), shift_mlp, scale_mlp) h = self.mlp(h) return x + gate_mlp * h @@ -157,11 +155,11 @@ class FinalLayer(nn.Module): def __init__(self, model_channels, out_channels): super().__init__() self.norm_final = nn.LayerNorm( - model_channels, elementwise_affine=False, eps=1e-6) + model_channels, elementwise_affine=False, eps=1e-6 + ) self.linear = nn.Linear(model_channels, out_channels, bias=True) self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(model_channels, 2 * model_channels, bias=True) + nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True) ) def forward(self, x, c): @@ -203,9 +201,11 @@ def __init__( res_blocks = [] for i in range(num_res_blocks): - res_blocks.append(ResBlock( - model_channels, - )) + res_blocks.append( + ResBlock( + model_channels, + ) + ) self.res_blocks = nn.ModuleList(res_blocks) self.final_layer = FinalLayer(model_channels, out_channels) @@ -218,6 +218,7 @@ def _basic_init(module): torch.nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.constant_(module.bias, 0) + self.apply(_basic_init) # Initialize timestep embedding MLP diff --git a/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py b/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py index bdb853cd72d39..a74e8e6cf23ca 100644 --- a/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py +++ b/iotdb-core/ainode/ainode/core/model/sundial/modeling_sundial.py @@ -17,27 +17,32 @@ # import os -from typing import Optional, Tuple, List, Union +from typing import List, Optional, Tuple, Union + import torch -from torch import nn import torch.nn.functional as F -from transformers import PreTrainedModel, Cache, DynamicCache +from huggingface_hub import hf_hub_download +from safetensors.torch import load_file as load_safetensors +from torch import nn +from transformers import Cache, DynamicCache, PreTrainedModel from transformers.activations import ACT2FN from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask -from transformers.modeling_outputs import MoeModelOutputWithPast, MoeCausalLMOutputWithPast +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) + +from ainode.core.log import Logger from ainode.core.model.sundial.configuration_sundial import SundialConfig -from ainode.core.model.sundial.ts_generation_mixin import TSGenerationMixin from ainode.core.model.sundial.flow_loss import FlowLoss +from ainode.core.model.sundial.ts_generation_mixin import TSGenerationMixin -from safetensors.torch import load_file as load_safetensors -from huggingface_hub import hf_hub_download - -from ainode.core.log import Logger logger = Logger() + def rotate_half(x): x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2:] + x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) @@ -54,25 +59,25 @@ def __init__(self, config: SundialConfig): super().__init__() self.dropout = nn.Dropout(config.dropout_rate) self.hidden_layer = nn.Linear( - 
config.input_token_len * 2, config.intermediate_size) + config.input_token_len * 2, config.intermediate_size + ) self.act = ACT2FN[config.hidden_act] - self.output_layer = nn.Linear( - config.intermediate_size, config.hidden_size) - self.residual_layer = nn.Linear( - config.input_token_len * 2, config.hidden_size) + self.output_layer = nn.Linear(config.intermediate_size, config.hidden_size) + self.residual_layer = nn.Linear(config.input_token_len * 2, config.hidden_size) self.input_token_len = config.input_token_len def forward(self, x): mask = torch.ones_like(x, dtype=torch.float32) input_length = x.shape[-1] - padding_length = (self.input_token_len - (input_length % - self.input_token_len)) % self.input_token_len + padding_length = ( + self.input_token_len - (input_length % self.input_token_len) + ) % self.input_token_len x = F.pad(x, (padding_length, 0)) mask = F.pad(mask, (padding_length, 0)) - x = x.unfold(dimension=-1, size=self.input_token_len, - step=self.input_token_len) + x = x.unfold(dimension=-1, size=self.input_token_len, step=self.input_token_len) mask = mask.unfold( - dimension=-1, size=self.input_token_len, step=self.input_token_len) + dimension=-1, size=self.input_token_len, step=self.input_token_len + ) x = torch.cat([x, mask], dim=-1) hid = self.act(self.hidden_layer(x)) @@ -88,33 +93,38 @@ def __init__(self, dim, max_position_embeddings=10000, base=10000, device=None): self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, - 2, dtype=torch.int64).float().to(device) / self.dim)) + inv_freq = 1.0 / ( + self.base + ** ( + torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) + / self.dim + ) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. 
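The rotary embedding being reformatted here precomputes a cos/sin cache from inverse frequencies and applies it with the rotate_half trick defined earlier in this file. For reference, a minimal self-contained sketch of the same computation (dim=64 and max_len=8 are arbitrary values chosen for brevity, not the model's configuration):

import torch

def build_rotary_cache(dim: int, max_len: int, base: int = 10000):
    # Inverse frequencies over even channel indices, matching inv_freq above.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_len).float()
    freqs = torch.outer(t, inv_freq)         # [max_len, dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)  # [max_len, dim]
    return emb.cos(), emb.sin()

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

cos, sin = build_rotary_cache(dim=64, max_len=8)
q = torch.randn(8, 64)                       # [seq_len, head_dim]
q_rotated = q * cos + rotate_half(q) * sin   # rotary position encoding applied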
self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), ) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, - dtype=torch.int64).type_as(self.inv_freq) + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=torch.int64 + ).type_as(self.inv_freq) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer( - "cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer( - "sin_cached", emb.sin().to(dtype), persistent=False) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache( - seq_len=seq_len, device=x.device, dtype=x.dtype) + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) return ( self.cos_cached[:seq_len].to(dtype=x.dtype), @@ -135,16 +145,17 @@ def __init__(self, config: SundialConfig, layer_idx: Optional[int] = None): self.v_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) self.rotary_emb = SundialRotaryEmbedding( - self.head_dim, max_position_embeddings=config.max_position_embeddings) + self.head_dim, max_position_embeddings=config.max_position_embeddings + ) def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - **kwargs, + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -153,26 +164,35 @@ def forward( value_states = self.v_proj(hidden_states) query_states = query_states.view( - bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) key_states = key_states.view( - bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) value_states = value_states.view( - bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length( - kv_seq_len, self.layer_idx) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids) + query_states, key_states, cos, sin, position_ids + ) if past_key_value is not None: key_states, value_states = past_key_value.update( - key_states, value_states, self.layer_idx) + key_states, value_states, self.layer_idx + ) attn_output = 
F.scaled_dot_product_attention( - query_states, key_states, value_states, attention_mask, dropout_p=(self.attention_dropout if self.training else 0.0)) + query_states, + key_states, + value_states, + attention_mask, + dropout_p=(self.attention_dropout if self.training else 0.0), + ) attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) @@ -189,16 +209,15 @@ def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str): super().__init__() self.hidden_size = hidden_size self.intermediate_size = intermediate_size - self.gate_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear( - self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear( - self.intermediate_size, self.hidden_size, bias=False) + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) self.act_fn = ACT2FN[hidden_act] def forward(self, hidden_state): - return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state)) + return self.down_proj( + self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state) + ) class SundialDecoderLayer(nn.Module): @@ -215,15 +234,20 @@ def __init__(self, config: SundialConfig, layer_idx: int): self.norm2 = torch.nn.LayerNorm(config.hidden_size) def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, - **kwargs, - ) -> Tuple[torch.FloatTensor, torch.FloatTensor, Optional[torch.FloatTensor], Optional[torch.FloatTensor]]: + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + Optional[torch.FloatTensor], + Optional[torch.FloatTensor], + ]: residual = hidden_states hidden_states = self.norm1(hidden_states) @@ -280,44 +304,56 @@ def __init__(self, config: SundialConfig): super().__init__(config) self.embed_layer = SundialPatchEmbedding(config) self.layers = nn.ModuleList( - [SundialDecoderLayer(config, layer_idx) - for layer_idx in range(config.num_hidden_layers)] + [ + SundialDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] ) self.norm = torch.nn.LayerNorm(config.hidden_size) self.gradient_checkpointing = False def forward( - self, - input_ids: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + self, + input_ids: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: 
Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, ) -> Union[Tuple, MoeModelOutputWithPast]: # input_ids is the input of time series, its shape is [batch_size, seq_len] - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states ) use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" + ) elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds") + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) if inputs_embeds is None: inputs_embeds = self.embed_layer(input_ids) @@ -332,15 +368,16 @@ def forward( if use_cache: use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache( - past_key_values) - past_key_values_length = past_key_values.get_usable_length( - seq_length) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, ) # position_ids = position_ids.unsqueeze(0).view(-1, seq_length) position_ids = position_ids.view(-1, seq_length) @@ -402,8 +439,11 @@ def forward( next_cache = None if use_cache: - next_cache = next_decoder_cache.to_legacy_cache( - ) if use_legacy_cache else next_decoder_cache + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) if not return_dict: return tuple( @@ -424,17 +464,28 @@ def __init__(self, config: SundialConfig): super().__init__(config) self.config = config self.model = SundialModel(self.config) - self.flow_loss = FlowLoss(self.config.output_token_lens[-1], self.config.hidden_size, - self.config.flow_loss_depth, self.config.hidden_size, self.config.num_sampling_steps) + self.flow_loss = FlowLoss( + self.config.output_token_lens[-1], + self.config.hidden_size, + self.config.flow_loss_depth, + self.config.hidden_size, + self.config.num_sampling_steps, + ) # TODO: Unify data loader if not os.path.exists(config.ckpt_path): os.mkdir(config.ckpt_path) weights_path = os.path.join(config.ckpt_path, 
"model.safetensors") if not os.path.exists(weights_path): - logger.info(f"Weight not found at {weights_path}, downloading from HuggingFace...") + logger.info( + f"Weight not found at {weights_path}, downloading from HuggingFace..." + ) repo_id = "thuml/sundial-base-128m" try: - hf_hub_download(repo_id=repo_id, filename="model.safetensors", local_dir=config.ckpt_path) + hf_hub_download( + repo_id=repo_id, + filename="model.safetensors", + local_dir=config.ckpt_path, + ) logger.info(f"Got weight to {weights_path}") except Exception as e: logger.error(f"Failed to download weight to {weights_path} due to {e}") @@ -449,34 +500,44 @@ def get_decoder(self): return self.model def forward( - self, - input_ids: torch.FloatTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.FloatTensor] = None, - loss_masks: Optional[torch.FloatTensor] = None, - mask_y: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - max_output_length: Optional[int] = None, - revin: Optional[bool] = False, - num_samples: Optional[int] = 1, + self, + input_ids: torch.FloatTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.FloatTensor] = None, + loss_masks: Optional[torch.FloatTensor] = None, + mask_y: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + max_output_length: Optional[int] = None, + revin: Optional[bool] = False, + num_samples: Optional[int] = 1, ) -> Union[Tuple, MoeCausalLMOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if revin: means = input_ids.mean(1, keepdim=True).detach() stdev = input_ids.std(dim=1, keepdim=True, unbiased=False).detach() - stdev = torch.where(stdev > 1e-2, stdev, torch.tensor(1.0, device=input_ids.device)) + stdev = torch.where( + stdev > 1e-2, stdev, torch.tensor(1.0, device=input_ids.device) + ) input_ids = (input_ids - means) / stdev outputs = self.model( input_ids=input_ids, @@ -499,18 +560,23 @@ def forward( labels = (labels - means) / stdev output_token_len = self.config.output_token_lens[-1] seq_len = hidden_states.shape[1] * self.config.input_token_len - labels = labels[:, :seq_len - - self.config.input_token_len + output_token_len] + labels = labels[ + :, : seq_len - self.config.input_token_len + output_token_len + ] shift_labels = labels.unfold( - dimension=-1, size=output_token_len, 
step=self.config.input_token_len) + dimension=-1, size=output_token_len, step=self.config.input_token_len + ) bsz, L, _ = shift_labels.shape - shift_labels = shift_labels.reshape( - bsz * L, -1).repeat(self.config.diffusion_batch_mul, 1) - hidden_states = hidden_states.reshape( - bsz * L, -1).repeat(self.config.diffusion_batch_mul, 1) - loss_masks = loss_masks.reshape( - bsz * L).repeat(self.config.diffusion_batch_mul) + shift_labels = shift_labels.reshape(bsz * L, -1).repeat( + self.config.diffusion_batch_mul, 1 + ) + hidden_states = hidden_states.reshape(bsz * L, -1).repeat( + self.config.diffusion_batch_mul, 1 + ) + loss_masks = loss_masks.reshape(bsz * L).repeat( + self.config.diffusion_batch_mul + ) mask_y = mask_y.repeat(L * self.config.diffusion_batch_mul, 1) loss = self.flow_loss(shift_labels, hidden_states, loss_masks, mask_y) @@ -546,7 +612,14 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, revin=False, num_samples=1, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + revin=False, + num_samples=1, + **kwargs, ): # Omit tokens covered by past_key_values if past_key_values is not None: @@ -566,21 +639,26 @@ def prepare_inputs_for_generation( # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) - if attention_mask is not None and attention_mask.shape[1] > (input_ids.shape[1] // self.config.input_token_len): - input_ids = input_ids[:, - - (attention_mask.shape[1] - past_length) * self.config.input_token_len:] + if attention_mask is not None and attention_mask.shape[1] > ( + input_ids.shape[1] // self.config.input_token_len + ): + input_ids = input_ids[ + :, + -(attention_mask.shape[1] - past_length) + * self.config.input_token_len :, + ] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < (input_ids.shape[1] // self.config.input_token_len): - input_ids = input_ids[:, past_length * - self.config.input_token_len:] + input_ids = input_ids[:, past_length * self.config.input_token_len :] # 3 - Otherwise (past_length >= (input_ids.shape[1] // self.config.input_token_len)), let's assume input_ids only has unprocessed tokens. # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. 
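Throughout prepare_inputs_for_generation, lengths are counted in patch tokens rather than raw time points: a series of N points spans N // input_token_len tokens, which is why every comparison divides input_ids.shape[1] by self.config.input_token_len. A toy walk-through of the cache-cropping condition checked just below (all numbers are hypothetical, chosen only to trigger the branch):

# Hypothetical sizes for illustration; none are Sundial defaults.
input_token_len = 16      # time points per patch token
max_cache_length = 4      # maximum tokens the KV cache may hold
cache_length = 3          # tokens already in the cache
raw_points = 32           # unprocessed time points in input_ids

new_tokens = raw_points // input_token_len   # 32 // 16 == 2 tokens
if cache_length + new_tokens > max_cache_length:
    # Mirrors: attention_mask = attention_mask[:, -max_cache_length:]
    print(f"crop attention_mask to its last {max_cache_length} positions")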
if ( - max_cache_length is not None - and attention_mask is not None - and cache_length + (input_ids.shape[1] // self.config.input_token_len) > max_cache_length + max_cache_length is not None + and attention_mask is not None + and cache_length + (input_ids.shape[1] // self.config.input_token_len) + > max_cache_length ): attention_mask = attention_mask[:, -max_cache_length:] @@ -590,8 +668,9 @@ def prepare_inputs_for_generation( position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, - - (input_ids.shape[1] // self.config.input_token_len):] + position_ids = position_ids[ + :, -(input_ids.shape[1] // self.config.input_token_len) : + ] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: @@ -609,4 +688,4 @@ def prepare_inputs_for_generation( "num_samples": num_samples, } ) - return model_inputs \ No newline at end of file + return model_inputs diff --git a/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py b/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py index 723b4a1332a72..d894d3d5ed3d0 100644 --- a/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py +++ b/iotdb-core/ainode/ainode/core/model/sundial/ts_generation_mixin.py @@ -17,46 +17,61 @@ # import warnings -from typing import Any, Dict, List, Optional, Union, Callable +from typing import Any, Callable, Dict, List, Optional, Union + import torch from transformers import GenerationMixin, LogitsProcessorList, StoppingCriteriaList -from transformers.generation import validate_stopping_criteria, EosTokenCriteria -from transformers.generation.utils import GenerateNonBeamOutput, GenerateEncoderDecoderOutput, \ - GenerateDecoderOnlyOutput, GenerationConfig, GenerateOutput +from transformers.generation import EosTokenCriteria, validate_stopping_criteria +from transformers.generation.utils import ( + GenerateDecoderOnlyOutput, + GenerateEncoderDecoderOutput, + GenerateNonBeamOutput, + GenerateOutput, + GenerationConfig, +) from transformers.utils import ModelOutput class TSGenerationMixin(GenerationMixin): @torch.no_grad() def generate( - self, - inputs: Optional[torch.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, - synced_gpus: Optional[bool] = None, - assistant_model: Optional["PreTrainedModel"] = None, - streamer: Optional["BaseStreamer"] = None, - negative_prompt_ids: Optional[torch.Tensor] = None, - negative_prompt_attention_mask: Optional[torch.Tensor] = None, - revin: Optional[bool] = True, - num_samples: Optional[int] = 1, - **kwargs, + self, + inputs: Optional[torch.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + prefix_allowed_tokens_fn: Optional[ + Callable[[int, torch.Tensor], List[int]] + ] = None, + synced_gpus: Optional[bool] = None, + assistant_model: Optional["PreTrainedModel"] = None, + streamer: Optional["BaseStreamer"] = None, + negative_prompt_ids: Optional[torch.Tensor] = None, + negative_prompt_attention_mask: Optional[torch.Tensor] = None, + revin: Optional[bool] = True, + num_samples: Optional[int] = 1, + **kwargs, ) -> 
Union[GenerateOutput, torch.LongTensor]: if len(inputs.shape) != 2: - raise ValueError('Input shape must be: [batch_size, seq_len]') + raise ValueError("Input shape must be: [batch_size, seq_len]") if revin: means = inputs.mean(dim=-1, keepdim=True) stdev = inputs.std(dim=-1, keepdim=True, unbiased=False) + 1e-5 inputs = (inputs - means) / stdev - outputs = super().generate(inputs=inputs, generation_config=generation_config, - logits_processor=logits_processor, stopping_criteria=stopping_criteria, - prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, synced_gpus=synced_gpus, - assistant_model=assistant_model, streamer=streamer, - negative_prompt_ids=negative_prompt_ids, - negative_prompt_attention_mask=negative_prompt_attention_mask, - num_samples=num_samples, **kwargs) + outputs = super().generate( + inputs=inputs, + generation_config=generation_config, + logits_processor=logits_processor, + stopping_criteria=stopping_criteria, + prefix_allowed_tokens_fn=prefix_allowed_tokens_fn, + synced_gpus=synced_gpus, + assistant_model=assistant_model, + streamer=streamer, + negative_prompt_ids=negative_prompt_ids, + negative_prompt_attention_mask=negative_prompt_attention_mask, + num_samples=num_samples, + **kwargs, + ) if revin: stdev = stdev.unsqueeze(1).repeat(1, num_samples, 1) means = means.unsqueeze(1).repeat(1, num_samples, 1) @@ -64,27 +79,33 @@ def generate( return outputs def _greedy_search( - self, - input_ids: torch.Tensor, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - output_logits: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - synced_gpus: bool = False, - streamer: Optional["BaseStreamer"] = None, - **model_kwargs, + self, + input_ids: torch.Tensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + output_logits: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + synced_gpus: bool = False, + streamer: Optional["BaseStreamer"] = None, + **model_kwargs, ) -> Union[GenerateNonBeamOutput, torch.Tensor]: input_ids = input_ids.to(self.device) batch_size, cur_len = input_ids.shape # init values - logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() - stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + logits_processor = ( + logits_processor if logits_processor is not None else LogitsProcessorList() + ) + stopping_criteria = ( + stopping_criteria + if stopping_criteria is not None + else StoppingCriteriaList() + ) if max_length is not None: warnings.warn( "`max_length` is deprecated in this function, use" @@ -92,31 +113,44 @@ def _greedy_search( UserWarning, ) stopping_criteria = validate_stopping_criteria( - stopping_criteria, max_length) - pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + stopping_criteria, max_length + ) + pad_token_id = ( + 
pad_token_id + if pad_token_id is not None + else self.generation_config.pad_token_id + ) if eos_token_id is not None: - stopping_criteria.append( - EosTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) else: # remove when the method is totally private # need to get `eos_token_id` and add stopping criteria, so that generation does not go forever eos_token_id = [ - criteria.eos_token_id.tolist() for criteria in stopping_criteria if hasattr(criteria, "eos_token_id") + criteria.eos_token_id.tolist() + for criteria in stopping_criteria + if hasattr(criteria, "eos_token_id") ] eos_token_id = eos_token_id[0] if eos_token_id else None if eos_token_id is None and self.generation_config.eos_token_id is not None: eos_token_id = self.generation_config.eos_token_id - stopping_criteria.append( - EosTokenCriteria(eos_token_id=eos_token_id)) + stopping_criteria.append(EosTokenCriteria(eos_token_id=eos_token_id)) if isinstance(eos_token_id, int): eos_token_id = [eos_token_id] - output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_scores = ( + output_scores + if output_scores is not None + else self.generation_config.output_scores + ) output_attentions = ( - output_attentions if output_attentions is not None else self.generation_config.output_attentions + output_attentions + if output_attentions is not None + else self.generation_config.output_attentions ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.generation_config.output_hidden_states ) return_dict_in_generate = ( return_dict_in_generate @@ -127,18 +161,27 @@ def _greedy_search( # init attention / hidden states / scores tuples raw_logits = () if (return_dict_in_generate and output_logits) else None scores = () if (return_dict_in_generate and output_scores) else None - decoder_attentions = () if (return_dict_in_generate and output_attentions) else None - cross_attentions = () if (return_dict_in_generate and output_attentions) else None - decoder_hidden_states = () if ( - return_dict_in_generate and output_hidden_states) else None + decoder_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + cross_attentions = ( + () if (return_dict_in_generate and output_attentions) else None + ) + decoder_hidden_states = ( + () if (return_dict_in_generate and output_hidden_states) else None + ) # if model is an encoder-decoder, retrieve encoder attention weights and hidden states if return_dict_in_generate and self.config.is_encoder_decoder: - encoder_attentions = model_kwargs["encoder_outputs"].get( - "attentions") if output_attentions else None + encoder_attentions = ( + model_kwargs["encoder_outputs"].get("attentions") + if output_attentions + else None + ) encoder_hidden_states = ( - model_kwargs["encoder_outputs"].get( - "hidden_states") if output_hidden_states else None + model_kwargs["encoder_outputs"].get("hidden_states") + if output_hidden_states + else None ) # keep track of which sequences are already finished @@ -146,17 +189,22 @@ def _greedy_search( cur_len = model_kwargs["inputs_embeds"].shape[1] this_peer_finished = False unfinished_sequences = torch.ones( - batch_size, dtype=torch.long, device=input_ids.device) - model_kwargs["cache_position"] = torch.arange( - cur_len, device=input_ids.device) - true_seq_len = (cur_len + 
self.config.input_token_len - 1) // self.config.input_token_len - model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, -true_seq_len:] + batch_size, dtype=torch.long, device=input_ids.device + ) + model_kwargs["cache_position"] = torch.arange(cur_len, device=input_ids.device) + true_seq_len = ( + cur_len + self.config.input_token_len - 1 + ) // self.config.input_token_len + model_kwargs["attention_mask"] = model_kwargs["attention_mask"][ + :, -true_seq_len: + ] max_length = stopping_criteria.max_length generate_results = None - while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device): + while self._has_unfinished_sequences( + this_peer_finished, synced_gpus, device=input_ids.device + ): # prepare model inputs - model_inputs = self.prepare_inputs_for_generation( - input_ids, **model_kwargs) + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) input_length = input_ids.shape[1] @@ -184,8 +232,9 @@ def _greedy_search( raw_logits += (next_token_logits,) if output_attentions: decoder_attentions += ( - (outputs.decoder_attentions,) if self.config.is_encoder_decoder else ( - outputs.attentions,) + (outputs.decoder_attentions,) + if self.config.is_encoder_decoder + else (outputs.attentions,) ) if self.config.is_encoder_decoder: cross_attentions += (outputs.cross_attentions,) @@ -205,9 +254,11 @@ def _greedy_search( if eos_token_id is not None: if pad_token_id is None: raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") - next_tokens = next_tokens * unfinished_sequences + \ - pad_token_id * (1 - unfinished_sequences) + "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." + ) + next_tokens = next_tokens * unfinished_sequences + pad_token_id * ( + 1 - unfinished_sequences + ) # update generated ids, model inputs, and length for next step horizon_length = next_tokens.shape[-1] // self.config.input_token_len @@ -228,7 +279,8 @@ def _greedy_search( is_encoder_decoder=self.config.is_encoder_decoder, ) unfinished_sequences = unfinished_sequences & ~stopping_criteria( - input_ids, scores) + input_ids, scores + ) this_peer_finished = unfinished_sequences.max() == 0 if input_ids.shape[-1] > max_length: @@ -260,15 +312,15 @@ def _greedy_search( past_key_values=model_kwargs.get("past_key_values"), ) else: - return generate_results[:, :, :(max_length - cur_len)] + return generate_results[:, :, : (max_length - cur_len)] def _update_model_kwargs_for_generation( - self, - outputs: ModelOutput, - model_kwargs: Dict[str, Any], - horizon_length: int = 1, - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, + self, + outputs: ModelOutput, + model_kwargs: Dict[str, Any], + horizon_length: int = 1, + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, ) -> Dict[str, Any]: # update past_key_values model_kwargs["past_key_values"] = self._extract_past_from_model_output( @@ -281,26 +333,42 @@ def _update_model_kwargs_for_generation( if "token_type_ids" in model_kwargs: token_type_ids = model_kwargs["token_type_ids"] model_kwargs["token_type_ids"] = torch.cat( - [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1) + [token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1 + ) if not is_encoder_decoder: # update attention mask if "attention_mask" in model_kwargs: attention_mask = model_kwargs["attention_mask"] model_kwargs["attention_mask"] = torch.cat( - [attention_mask, attention_mask.new_ones((attention_mask.shape[0], 
horizon_length))], dim=-1 + [ + attention_mask, + attention_mask.new_ones( + (attention_mask.shape[0], horizon_length) + ), + ], + dim=-1, ) else: # update decoder attention mask if "decoder_attention_mask" in model_kwargs: decoder_attention_mask = model_kwargs["decoder_attention_mask"] model_kwargs["decoder_attention_mask"] = torch.cat( - [decoder_attention_mask, decoder_attention_mask.new_ones( - (decoder_attention_mask.shape[0], horizon_length))], + [ + decoder_attention_mask, + decoder_attention_mask.new_ones( + (decoder_attention_mask.shape[0], horizon_length) + ), + ], dim=-1, ) - if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None: - model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + horizon_length + if ( + "cache_position" in model_kwargs + and model_kwargs["cache_position"] is not None + ): + model_kwargs["cache_position"] = ( + model_kwargs["cache_position"][-1:] + horizon_length + ) - return model_kwargs \ No newline at end of file + return model_kwargs diff --git a/iotdb-core/ainode/ainode/core/script.py b/iotdb-core/ainode/ainode/core/script.py index b27bb6ab61bc0..84a44924828b7 100644 --- a/iotdb-core/ainode/ainode/core/script.py +++ b/iotdb-core/ainode/ainode/core/script.py @@ -24,31 +24,39 @@ from ainode.core.client import ClientManager from ainode.core.config import AINodeDescriptor -from ainode.core.constant import TSStatusCode, AINODE_SYSTEM_FILE_NAME +from ainode.core.constant import AINODE_SYSTEM_FILE_NAME, TSStatusCode from ainode.core.exception import MissingConfigError from ainode.core.log import Logger from ainode.core.service import RPCService -from ainode.thrift.common.ttypes import TAINodeLocation, TEndPoint, TAINodeConfiguration, TNodeResource +from ainode.thrift.common.ttypes import ( + TAINodeConfiguration, + TAINodeLocation, + TEndPoint, + TNodeResource, +) from ainode.thrift.confignode.ttypes import TNodeVersionInfo logger = Logger() def _generate_configuration() -> TAINodeConfiguration: - location = TAINodeLocation(AINodeDescriptor().get_config().get_ainode_id(), - TEndPoint(AINodeDescriptor().get_config().get_ain_inference_rpc_address(), - AINodeDescriptor().get_config().get_ain_inference_rpc_port())) - resource = TNodeResource( - int(psutil.cpu_count()), - int(psutil.virtual_memory()[0]) + location = TAINodeLocation( + AINodeDescriptor().get_config().get_ainode_id(), + TEndPoint( + AINodeDescriptor().get_config().get_ain_inference_rpc_address(), + AINodeDescriptor().get_config().get_ain_inference_rpc_port(), + ), ) + resource = TNodeResource(int(psutil.cpu_count()), int(psutil.virtual_memory()[0])) return TAINodeConfiguration(location, resource) def _generate_version_info() -> TNodeVersionInfo: - return TNodeVersionInfo(AINodeDescriptor().get_config().get_version_info(), - AINodeDescriptor().get_config().get_build_info()) + return TNodeVersionInfo( + AINodeDescriptor().get_config().get_version_info(), + AINodeDescriptor().get_config().get_build_info(), + ) def _check_path_permission(): @@ -64,44 +72,58 @@ def _check_path_permission(): def start_ainode(): _check_path_permission() - system_properties_file = os.path.join(AINodeDescriptor().get_config().get_ain_system_dir(), AINODE_SYSTEM_FILE_NAME) + system_properties_file = os.path.join( + AINodeDescriptor().get_config().get_ain_system_dir(), AINODE_SYSTEM_FILE_NAME + ) if not os.path.exists(system_properties_file): # If the system.properties file does not exist, the AINode will register to ConfigNode. 
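On first registration the returned ainode_id, together with the cluster and RPC settings, is flushed to system.properties as one key=value pair per line behind a timestamp comment, as the branch below shows. An illustrative file (every value here is a placeholder, not a recorded default):

#2024-01-01 12:00:00.000000
ainode_id=1
cluster_name=defaultCluster
iotdb_version=x.y.z
commit_id=0000000
ain_rpc_address=127.0.0.1
ain_rpc_port=10810
config_node_list=127.0.0.1:10710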
try: - logger.info('IoTDB-AINode is registering to ConfigNode...') - ainode_id = ClientManager().borrow_config_node_client().node_register( - AINodeDescriptor().get_config().get_cluster_name(), - _generate_configuration(), - _generate_version_info()) + logger.info("IoTDB-AINode is registering to ConfigNode...") + ainode_id = ( + ClientManager() + .borrow_config_node_client() + .node_register( + AINodeDescriptor().get_config().get_cluster_name(), + _generate_configuration(), + _generate_version_info(), + ) + ) AINodeDescriptor().get_config().set_ainode_id(ainode_id) system_properties = { - 'ainode_id': ainode_id, - 'cluster_name': AINodeDescriptor().get_config().get_cluster_name(), - 'iotdb_version': AINodeDescriptor().get_config().get_version_info(), - 'commit_id': AINodeDescriptor().get_config().get_build_info(), - 'ain_rpc_address': AINodeDescriptor().get_config().get_ain_inference_rpc_address(), - 'ain_rpc_port': AINodeDescriptor().get_config().get_ain_inference_rpc_port(), - 'config_node_list': AINodeDescriptor().get_config().get_ain_target_config_node_list(), + "ainode_id": ainode_id, + "cluster_name": AINodeDescriptor().get_config().get_cluster_name(), + "iotdb_version": AINodeDescriptor().get_config().get_version_info(), + "commit_id": AINodeDescriptor().get_config().get_build_info(), + "ain_rpc_address": AINodeDescriptor() + .get_config() + .get_ain_inference_rpc_address(), + "ain_rpc_port": AINodeDescriptor() + .get_config() + .get_ain_inference_rpc_port(), + "config_node_list": AINodeDescriptor() + .get_config() + .get_ain_target_config_node_list(), } - with open(system_properties_file, 'w') as f: - f.write('#' + str(datetime.now()) + '\n') + with open(system_properties_file, "w") as f: + f.write("#" + str(datetime.now()) + "\n") for key, value in system_properties.items(): - f.write(key + '=' + str(value) + '\n') + f.write(key + "=" + str(value) + "\n") except Exception as e: - logger.error('IoTDB-AINode failed to register to ConfigNode: {}'.format(e)) + logger.error("IoTDB-AINode failed to register to ConfigNode: {}".format(e)) raise e else: # If the system.properties file does exist, the AINode will just restart. 
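# A sketch of the system.properties layout that the registration branch above
# writes on first startup; every value below is an illustrative placeholder,
# not a default taken from this patch.
from datetime import datetime

system_properties = {
    "ainode_id": 1,                          # assumed id returned by node_register
    "cluster_name": "defaultCluster",        # assumed cluster name
    "iotdb_version": "2.0.4",                # assumed version string
    "commit_id": "unknown",                  # assumed build info
    "ain_rpc_address": "127.0.0.1",          # assumed inference RPC address
    "ain_rpc_port": 10810,                   # assumed inference RPC port
    "config_node_list": "127.0.0.1:10710",   # assumed ConfigNode endpoint
}
with open("system.properties", "w") as f:
    f.write("#" + str(datetime.now()) + "\n")       # timestamp comment line
    for key, value in system_properties.items():
        f.write(key + "=" + str(value) + "\n")      # one key=value per line
# Resulting file: a "#<timestamp>" header followed by plain key=value lines,
# e.g. "ainode_id=1", "cluster_name=defaultCluster", and so on.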
try: - logger.info('IoTDB-AINode is restarting...') + logger.info("IoTDB-AINode is restarting...") ClientManager().borrow_config_node_client().node_restart( AINodeDescriptor().get_config().get_cluster_name(), _generate_configuration(), - _generate_version_info()) + _generate_version_info(), + ) except Exception as e: - logger.error('IoTDB-AINode failed to restart: {}'.format(e)) + logger.error("IoTDB-AINode failed to restart: {}".format(e)) raise e rpc_service = RPCService() @@ -110,39 +132,48 @@ def start_ainode(): if rpc_service.exit_code != 0: return - logger.info('IoTDB-AINode has successfully started.') + logger.info("IoTDB-AINode has successfully started.") def remove_ainode(arguments): # Delete the current node if len(arguments) == 2: target_ainode_id = AINodeDescriptor().get_config().get_ainode_id() - target_rpc_address = AINodeDescriptor().get_config().get_ain_inference_rpc_address() + target_rpc_address = ( + AINodeDescriptor().get_config().get_ain_inference_rpc_address() + ) target_rpc_port = AINodeDescriptor().get_config().get_ain_inference_rpc_port() # Delete the node with a given id elif len(arguments) == 3: target_ainode_id = int(arguments[2]) - ainode_configuration_map = ClientManager().borrow_config_node_client().get_ainode_configuration( - target_ainode_id) + ainode_configuration_map = ( + ClientManager() + .borrow_config_node_client() + .get_ainode_configuration(target_ainode_id) + ) end_point = ainode_configuration_map[target_ainode_id].location.internalEndPoint target_rpc_address = end_point.ip target_rpc_port = end_point.port if not end_point: - raise MissingConfigError("NodeId: {} not found in cluster ".format(target_ainode_id)) + raise MissingConfigError( + "NodeId: {} not found in cluster ".format(target_ainode_id) + ) - logger.info('Got target AINode id: {}'.format(target_ainode_id)) + logger.info("Got target AINode id: {}".format(target_ainode_id)) else: raise MissingConfigError("Invalid command") - location = TAINodeLocation(target_ainode_id, TEndPoint(target_rpc_address, target_rpc_port)) + location = TAINodeLocation( + target_ainode_id, TEndPoint(target_rpc_address, target_rpc_port) + ) status = ClientManager().borrow_config_node_client().node_remove(location) if status.code == TSStatusCode.SUCCESS_STATUS.get_status_code(): - logger.info('IoTDB-AINode has successfully removed.') + logger.info("IoTDB-AINode has successfully removed.") if os.path.exists(AINodeDescriptor().get_config().get_ain_models_dir()): shutil.rmtree(AINodeDescriptor().get_config().get_ain_models_dir()) @@ -155,14 +186,14 @@ def main(): logger.info("Command line argument must be specified.") return command = arguments[1] - if command == 'start': + if command == "start": try: - logger.info('IoTDB-AINode is starting...') + logger.info("IoTDB-AINode is starting...") start_ainode() except Exception as e: logger.error("Start AINode failed, because of: {}".format(e)) sys.exit(1) - elif command == 'remove': + elif command == "remove": try: logger.info("Removing AINode...") remove_ainode(arguments) @@ -173,5 +204,5 @@ def main(): logger.warning("Unknown argument: {}.".format(command)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/iotdb-core/ainode/ainode/core/service.py b/iotdb-core/ainode/ainode/core/service.py index 9532093d1dade..7602ebe9f192d 100644 --- a/iotdb-core/ainode/ainode/core/service.py +++ b/iotdb-core/ainode/ainode/core/service.py @@ -17,7 +17,7 @@ # import threading -from thrift.protocol import TCompactProtocol, TBinaryProtocol +from thrift.protocol 
import TBinaryProtocol, TCompactProtocol from thrift.server import TServer from thrift.transport import TSocket, TTransport @@ -34,15 +34,19 @@ def __init__(self): self.exit_code = 0 super().__init__() processor = IAINodeRPCService.Processor(handler=AINodeRPCServiceHandler()) - transport = TSocket.TServerSocket(host=AINodeDescriptor().get_config().get_ain_inference_rpc_address(), - port=AINodeDescriptor().get_config().get_ain_inference_rpc_port()) + transport = TSocket.TServerSocket( + host=AINodeDescriptor().get_config().get_ain_inference_rpc_address(), + port=AINodeDescriptor().get_config().get_ain_inference_rpc_port(), + ) transport_factory = TTransport.TFramedTransportFactory() if AINodeDescriptor().get_config().get_ain_thrift_compression_enabled(): protocol_factory = TCompactProtocol.TCompactProtocolFactory() else: protocol_factory = TBinaryProtocol.TBinaryProtocolFactory() - self.__pool_server = TServer.TThreadPoolServer(processor, transport, transport_factory, protocol_factory) + self.__pool_server = TServer.TThreadPoolServer( + processor, transport, transport_factory, protocol_factory + ) def run(self) -> None: logger.info("The RPC service thread begin to run...") diff --git a/iotdb-core/ainode/ainode/core/util/activation.py b/iotdb-core/ainode/ainode/core/util/activation.py index 25be5dc2b3996..ce7fa364b5e01 100644 --- a/iotdb-core/ainode/ainode/core/util/activation.py +++ b/iotdb-core/ainode/ainode/core/util/activation.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. # +from collections import OrderedDict + import torch import torch.nn as nn import torch.nn.functional as F -from collections import OrderedDict + class ClampedGELU(nn.Module): def __init__(self, min_val=-10, max_val=10): @@ -30,12 +32,14 @@ def __init__(self, min_val=-10, max_val=10): def forward(self, x): return torch.clamp(self.act(x), self.min_val, self.max_val) + class ClassInstantier(OrderedDict): def __getitem__(self, key): content = super().__getitem__(key) cls, kwargs = content if isinstance(content, tuple) else (content, {}) return cls(**kwargs) + ACT2CLS = { "gelu": nn.GELU, "gelu_10": (ClampedGELU, {"min": -10, "max": 10}), @@ -48,4 +52,4 @@ def __getitem__(self, key): "tanh": nn.Tanh, "prelu": nn.PReLU, } -ACT2FN = ClassInstantier(ACT2CLS) \ No newline at end of file +ACT2FN = ClassInstantier(ACT2CLS) diff --git a/iotdb-core/ainode/ainode/core/util/huggingface_cache.py b/iotdb-core/ainode/ainode/core/util/huggingface_cache.py index 1f8516f33defa..d6365ee717db5 100644 --- a/iotdb-core/ainode/ainode/core/util/huggingface_cache.py +++ b/iotdb-core/ainode/ainode/core/util/huggingface_cache.py @@ -21,13 +21,16 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple + import torch import torch.nn as nn + class Cache: """ Base, abstract class for all caches. The actual data structure is specific to each subclass. """ + # def __init__(self): # # to avoid torch.jit.script error # super().__init__() @@ -61,13 +64,19 @@ def update( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" - raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + raise NotImplementedError( + "Make sure to implement `get_seq_length` in a subclass." 
+ ) def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states, if there is any.""" - raise NotImplementedError("Make sure to implement `get_max_length` in a subclass.") + raise NotImplementedError( + "Make sure to implement `get_max_length` in a subclass." + ) - def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + def get_usable_length( + self, new_seq_length: int, layer_idx: Optional[int] = 0 + ) -> int: """Given the sequence length of the new inputs, returns the usable length of the cache.""" # Cache without size limit -> all cache is usable # Cache with size limit -> if the length cache plus the length of the new inputs is larger the maximum cache @@ -84,7 +93,8 @@ def seen_tokens(self): return self._seen_tokens else: return None - + + class DynamicCache(Cache): """ A cache that grows dynamically as more tokens are generated. This is the default for generative models. @@ -96,7 +106,9 @@ class DynamicCache(Cache): def __init__(self) -> None: self.key_cache: List[torch.Tensor] = [] self.value_cache: List[torch.Tensor] = [] - self._seen_tokens = 0 # Used in `generate` to keep tally of how many tokens the cache has seen + self._seen_tokens = ( + 0 # Used in `generate` to keep tally of how many tokens the cache has seen + ) def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: """ @@ -106,7 +118,9 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: if layer_idx < len(self): return (self.key_cache[layer_idx], self.value_cache[layer_idx]) else: - raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + raise KeyError( + f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}" + ) # def __iter__(self): # """ @@ -154,8 +168,12 @@ def update( self.key_cache.append(key_states) self.value_cache.append(value_states) else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + self.key_cache[layer_idx] = torch.cat( + [self.key_cache[layer_idx], key_states], dim=-2 + ) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx], value_states], dim=-2 + ) return self.key_cache[layer_idx], self.value_cache[layer_idx] @@ -173,9 +191,13 @@ def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" for layer_idx in range(len(self.key_cache)): device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select( + 0, beam_idx.to(device) + ) device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select( + 0, beam_idx.to(device) + ) def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format.""" @@ -184,14 +206,18 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) return legacy_cache - def init_data(self, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None): + def init_data( + self, 
past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None + ): if past_key_values is not None: for layer_idx in range(len(past_key_values)): key_states, value_states = past_key_values[layer_idx] self.update(key_states, value_states, layer_idx) @classmethod - def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + def from_legacy_cache( + cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "DynamicCache": """Converts a cache in the legacy cache format into an equivalent `DynamicCache`.""" cache = cls() if past_key_values is not None: @@ -199,4 +225,3 @@ def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTens key_states, value_states = past_key_values[layer_idx] cache.update(key_states, value_states, layer_idx) return cache - \ No newline at end of file diff --git a/iotdb-core/ainode/ainode/core/util/masking.py b/iotdb-core/ainode/ainode/core/util/masking.py index 826e05d67cfb6..a182d55fd7f3a 100644 --- a/iotdb-core/ainode/ainode/core/util/masking.py +++ b/iotdb-core/ainode/ainode/core/util/masking.py @@ -17,43 +17,54 @@ # import torch -class TriangularCausalMask(): + +class TriangularCausalMask: def __init__(self, B, L, device="cpu"): mask_shape = [B, 1, L, L] with torch.no_grad(): - self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) + self._mask = torch.triu( + torch.ones(mask_shape, dtype=torch.bool), diagonal=1 + ).to(device) @property def mask(self): return self._mask -class TimerMultivariateMask(): + +class TimerMultivariateMask: def __init__(self, B, n_vars, n_tokens, device="cpu"): mask_shape = [B, 1, n_tokens, n_tokens] with torch.no_grad(): self._mask1 = torch.ones((n_vars, n_vars), dtype=torch.bool).to(device) - self._mask2 = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device) + self._mask2 = torch.triu( + torch.ones(mask_shape, dtype=torch.bool), diagonal=1 + ).to(device) self._mask = torch.kron(self._mask1, self._mask2) + @property def mask(self): return self._mask -class TimerCovariateMask(): + +class TimerCovariateMask: def __init__(self, B, n_vars, n_tokens, device="cpu"): mask_shape = [B, 1, n_tokens, n_tokens] with torch.no_grad(): self._mask1 = torch.eye(n_vars, dtype=torch.bool).to(device) - self._mask2 = torch.tril(torch.ones(mask_shape, dtype=torch.bool)).to(device) + self._mask2 = torch.tril(torch.ones(mask_shape, dtype=torch.bool)).to( + device + ) self._mask = ~torch.kron(self._mask1, self._mask2) self._mask[:, :, -n_tokens:, :-n_tokens] = False - + @property def mask(self): return self._mask - + + def prepare_4d_causal_attention_mask( attention_mask, - input_shape, # (B, T_query) + input_shape, # (B, T_query) inputs_embeds: torch.Tensor, past_key_values_length: int = 0, ): @@ -62,20 +73,21 @@ def prepare_4d_causal_attention_mask( dtype, device = inputs_embeds.dtype, inputs_embeds.device # 1) causal mask - q_pos = torch.arange(past_key_values_length, - past_key_values_length + T, device=device) # [T] - k_pos = torch.arange(S, device=device) # [S] - causal = (k_pos.unsqueeze(0) <= q_pos.unsqueeze(1)) # [T,S] bool + q_pos = torch.arange( + past_key_values_length, past_key_values_length + T, device=device + ) # [T] + k_pos = torch.arange(S, device=device) # [S] + causal = k_pos.unsqueeze(0) <= q_pos.unsqueeze(1) # [T,S] bool mask = torch.zeros((T, S), dtype=dtype, device=device) - mask.masked_fill_(~causal, torch.finfo(dtype).min) # unvisible → -inf - mask = 
mask.unsqueeze(0).unsqueeze(0) # [1,1,T,S] + mask.masked_fill_(~causal, torch.finfo(dtype).min) # unvisible → -inf + mask = mask.unsqueeze(0).unsqueeze(0) # [1,1,T,S] # 2) padding mask if attention_mask is not None: pad = (1.0 - attention_mask.to(dtype)) * torch.finfo(dtype).min # [B,S] - pad = pad[:, None, None, :] # [B,1,1,S] + pad = pad[:, None, None, :] # [B,1,1,S] else: - pad = 0. + pad = 0.0 - return mask + pad # [B,1,T,S] + return mask + pad # [B,1,T,S] diff --git a/iotdb-core/ainode/ainode/core/util/serde.py b/iotdb-core/ainode/ainode/core/util/serde.py index 70b86d6609596..affd96992ee9e 100644 --- a/iotdb-core/ainode/ainode/core/util/serde.py +++ b/iotdb-core/ainode/ainode/core/util/serde.py @@ -72,7 +72,7 @@ def convert_to_binary(data_frame: pd.DataFrame): binary += position_count.to_bytes(4, byteorder="big") # column encoding - binary += b'\x02' + binary += b"\x02" for data_type in data_frame.dtypes: binary += _get_encoder(data_type) @@ -90,7 +90,7 @@ def convert_to_binary(data_frame: pd.DataFrame): col = data_frame[keys[i]] for j in range(position_count): value = col[j] - if value.dtype.byteorder != '>': + if value.dtype.byteorder != ">": value = value.byteswap() binary += value.tobytes() @@ -99,44 +99,50 @@ def convert_to_binary(data_frame: pd.DataFrame): def _get_encoder(data_type: pd.Series): if data_type == "bool": - return b'\x00' + return b"\x00" elif data_type == "int32" or data_type == "float32": - return b'\x01' + return b"\x01" elif data_type == "int64" or data_type == "float64": - return b'\x02' + return b"\x02" elif data_type == "texr": - return b'\x03' + return b"\x03" def _get_type_in_byte(data_type: pd.Series): - if data_type == 'bool': - return b'\x00' - elif data_type == 'int32': - return b'\x01' - elif data_type == 'int64': - return b'\x02' - elif data_type == 'float32': - return b'\x03' - elif data_type == 'float64': - return b'\x04' - elif data_type == 'text': - return b'\x05' + if data_type == "bool": + return b"\x00" + elif data_type == "int32": + return b"\x01" + elif data_type == "int64": + return b"\x02" + elif data_type == "float32": + return b"\x03" + elif data_type == "float64": + return b"\x04" + elif data_type == "text": + return b"\x05" else: - raise BadConfigValueError('data_type', data_type, - "data_type should be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text']") + raise BadConfigValueError( + "data_type", + data_type, + "data_type should be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text']", + ) # General Methods def get_data_type_byte_from_str(value): - ''' + """ Args: value (str): data type in ['bool', 'int32', 'int64', 'float32', 'float64', 'text'] Returns: byte: corresponding data type in [b'\x00', b'\x01', b'\x02', b'\x03', b'\x04', b'\x05'] - ''' - if value not in ['bool', 'int32', 'int64', 'float32', 'float64', 'text']: - raise BadConfigValueError('data_type', value, - "data_type should be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text']") + """ + if value not in ["bool", "int32", "int64", "float32", "float64", "text"]: + raise BadConfigValueError( + "data_type", + value, + "data_type should be in ['bool', 'int32', 'int64', 'float32', 'float64', 'text']", + ) if value == "bool": return TSDataType.BOOLEAN.value elif value == "int32": diff --git a/iotdb-core/ainode/poetry.lock b/iotdb-core/ainode/poetry.lock index bf82e69472327..cb6f7e188e02c 100644 --- a/iotdb-core/ainode/poetry.lock +++ b/iotdb-core/ainode/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.1 and should not 
be changed by hand. +# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. [[package]] name = "alembic" @@ -13,8 +13,6 @@ files = [ ] [package.dependencies] -importlib-metadata = {version = "*", markers = "python_version < \"3.9\""} -importlib-resources = {version = "*", markers = "python_version < \"3.9\""} Mako = "*" SQLAlchemy = ">=1.3.0" typing-extensions = ">=4" @@ -22,6 +20,73 @@ typing-extensions = ">=4" [package.extras] tz = ["backports.zoneinfo ; python_version < \"3.9\""] +[[package]] +name = "apache-iotdb" +version = "2.0.4.dev0" +description = "Apache IoTDB client API" +optional = false +python-versions = ">=3.6" +groups = ["main"] +files = [ + {file = "apache_iotdb-2.0.4.dev0-py3-none-any.whl", hash = "sha256:364639e357bf4f15032c401d65711f1a87aee9943a7886c0cfd38dad5daf214e"}, + {file = "apache_iotdb-2.0.4.dev0.tar.gz", hash = "sha256:c43d6619d988b6a1bca1ae9ea183b42052fbee9c98fd42482df83685ccf96d7d"}, +] + +[package.dependencies] +numpy = ">=1.0.0" +pandas = ">=1.0.0" +sqlalchemy = ">=1.4" +sqlalchemy-utils = ">=0.37.8" +thrift = ">=0.14.1" +tzlocal = ">=4.0" + +[[package]] +name = "black" +version = "25.1.0" +description = "The uncompromising code formatter." +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "black-25.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:759e7ec1e050a15f89b770cefbf91ebee8917aac5c20483bc2d80a6c3a04df32"}, + {file = "black-25.1.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0e519ecf93120f34243e6b0054db49c00a35f84f195d5bce7e9f5cfc578fc2da"}, + {file = "black-25.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:055e59b198df7ac0b7efca5ad7ff2516bca343276c466be72eb04a3bcc1f82d7"}, + {file = "black-25.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:db8ea9917d6f8fc62abd90d944920d95e73c83a5ee3383493e35d271aca872e9"}, + {file = "black-25.1.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a39337598244de4bae26475f77dda852ea00a93bd4c728e09eacd827ec929df0"}, + {file = "black-25.1.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:96c1c7cd856bba8e20094e36e0f948718dc688dba4a9d78c3adde52b9e6c2299"}, + {file = "black-25.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bce2e264d59c91e52d8000d507eb20a9aca4a778731a08cfff7e5ac4a4bb7096"}, + {file = "black-25.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:172b1dbff09f86ce6f4eb8edf9dede08b1fce58ba194c87d7a4f1a5aa2f5b3c2"}, + {file = "black-25.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b60580e829091e6f9238c848ea6750efed72140b91b048770b64e74fe04908b"}, + {file = "black-25.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e2978f6df243b155ef5fa7e558a43037c3079093ed5d10fd84c43900f2d8ecc"}, + {file = "black-25.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b48735872ec535027d979e8dcb20bf4f70b5ac75a8ea99f127c106a7d7aba9f"}, + {file = "black-25.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea0213189960bda9cf99be5b8c8ce66bb054af5e9e861249cd23471bd7b0b3ba"}, + {file = "black-25.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8f0b18a02996a836cc9c9c78e5babec10930862827b1b724ddfe98ccf2f2fe4f"}, + {file = "black-25.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:afebb7098bfbc70037a053b91ae8437c3857482d3a690fefc03e9ff7aa9a5fd3"}, + {file = "black-25.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = 
"sha256:030b9759066a4ee5e5aca28c3c77f9c64789cdd4de8ac1df642c40b708be6171"}, + {file = "black-25.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:a22f402b410566e2d1c950708c77ebf5ebd5d0d88a6a2e87c86d9fb48afa0d18"}, + {file = "black-25.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a1ee0a0c330f7b5130ce0caed9936a904793576ef4d2b98c40835d6a65afa6a0"}, + {file = "black-25.1.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3df5f1bf91d36002b0a75389ca8663510cf0531cca8aa5c1ef695b46d98655f"}, + {file = "black-25.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d9e6827d563a2c820772b32ce8a42828dc6790f095f441beef18f96aa6f8294e"}, + {file = "black-25.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:bacabb307dca5ebaf9c118d2d2f6903da0d62c9faa82bd21a33eecc319559355"}, + {file = "black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717"}, + {file = "black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666"}, +] + +[package.dependencies] +click = ">=8.0.0" +mypy-extensions = ">=0.4.3" +packaging = ">=22.0" +pathspec = ">=0.9.0" +platformdirs = ">=2" +tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} +typing-extensions = {version = ">=4.0.1", markers = "python_version < \"3.11\""} + +[package.extras] +colorama = ["colorama (>=0.4.3)"] +d = ["aiohttp (>=3.10)"] +jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"] +uvloop = ["uvloop (>=0.15.2)"] + [[package]] name = "certifi" version = "2024.7.4" @@ -134,6 +199,38 @@ files = [ {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, ] +[[package]] +name = "click" +version = "8.1.8" +description = "Composable command line interface toolkit" +optional = false +python-versions = ">=3.7" +groups = ["main"] +markers = "python_version == \"3.9\"" +files = [ + {file = "click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2"}, + {file = "click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[[package]] +name = "click" +version = "8.2.1" +description = "Composable command line interface toolkit" +optional = false +python-versions = ">=3.10" +groups = ["main"] +markers = "python_version >= \"3.10\"" +files = [ + {file = "click-8.2.1-py3-none-any.whl", hash = "sha256:61a3265b914e850b85317d0b3109c7f8cd35a670f963866005d6ef1d5175a12b"}, + {file = "click-8.2.1.tar.gz", hash = "sha256:27c491cc05d968d271d5a1db13e3b5a184636d9d930f148c50b038f0d0646202"}, +] + +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + [[package]] name = "colorama" version = "0.4.6" @@ -141,7 +238,7 @@ description = "Cross-platform colored terminal text." 
optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" groups = ["main"] -markers = "sys_platform == \"win32\" or platform_system == \"Windows\"" +markers = "platform_system == \"Windows\" or sys_platform == \"win32\"" files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, @@ -263,6 +360,18 @@ toml = ["toml"] vault = ["hvac"] yaml = ["ruamel.yaml"] +[[package]] +name = "einops" +version = "0.8.1" +description = "A new flavour of deep learning operations" +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "einops-0.8.1-py3-none-any.whl", hash = "sha256:919387eb55330f5757c6bea9165c5ff5cfe63a642682ea788a6d472576d81737"}, + {file = "einops-0.8.1.tar.gz", hash = "sha256:de5d960a7a761225532e0f1959e5315ebeafc0cd43394732f103ca44b9837e84"}, +] + [[package]] name = "filelock" version = "3.15.4" @@ -439,61 +548,68 @@ docs = ["matplotlib", "pydata-sphinx-theme", "sphinx (>=2.0)", "sphinx-gallery"] tests = ["pytest"] [[package]] -name = "idna" -version = "3.8" -description = "Internationalized Domain Names in Applications (IDNA)" +name = "huggingface-hub" +version = "0.30.2" +description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8.0" groups = ["main"] files = [ - {file = "idna-3.8-py3-none-any.whl", hash = "sha256:050b4e5baadcd44d760cedbd2b8e639f2ff89bbc7a5730fcc662954303377aac"}, - {file = "idna-3.8.tar.gz", hash = "sha256:d838c2c0ed6fced7693d5e8ab8e734d5f8fda53a039c0164afb0b82e771e3603"}, + {file = "huggingface_hub-0.30.2-py3-none-any.whl", hash = "sha256:68ff05969927058cfa41df4f2155d4bb48f5f54f719dd0390103eefa9b191e28"}, + {file = "huggingface_hub-0.30.2.tar.gz", hash = "sha256:9a7897c5b6fd9dad3168a794a8998d6378210f5b9688d0dfc180b1a228dc2466"}, ] +[package.dependencies] +filelock = "*" +fsspec = ">=2023.5.0" +packaging = ">=20.9" +pyyaml = ">=5.1" +requests = "*" +tqdm = ">=4.42.1" +typing-extensions = ">=3.7.4.3" + +[package.extras] +all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +cli = ["InquirerPy (==0.3.4)"] +dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "libcst (==1.4.0)", "mypy (==1.5.1)", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "ruff (>=0.9.0)", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)", "urllib3 (<2.0)"] +fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"] +hf-transfer = ["hf-transfer (>=0.1.4)"] +hf-xet = ["hf-xet (>=0.1.4)"] +inference = ["aiohttp"] +quality = ["libcst (==1.4.0)", "mypy (==1.5.1)", "ruff (>=0.9.0)"] +tensorflow = ["graphviz", "pydot", "tensorflow"] +tensorflow-testing = ["keras (<3.0)", 
"tensorflow"] +testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "aiohttp", "fastapi", "gradio (>=4.0.0)", "jedi", "numpy", "pytest (>=8.1.1,<8.2.2)", "pytest-asyncio", "pytest-cov", "pytest-env", "pytest-mock", "pytest-rerunfailures", "pytest-vcr", "pytest-xdist", "soundfile", "urllib3 (<2.0)"] +torch = ["safetensors[torch]", "torch"] +typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3", "typing-extensions (>=4.8.0)"] + [[package]] -name = "importlib-metadata" -version = "8.4.0" -description = "Read metadata from Python packages" +name = "idna" +version = "3.8" +description = "Internationalized Domain Names in Applications (IDNA)" optional = false -python-versions = ">=3.8" +python-versions = ">=3.6" groups = ["main"] -markers = "python_version < \"3.9\"" files = [ - {file = "importlib_metadata-8.4.0-py3-none-any.whl", hash = "sha256:66f342cc6ac9818fc6ff340576acd24d65ba0b3efabb2b4ac08b598965a4a2f1"}, - {file = "importlib_metadata-8.4.0.tar.gz", hash = "sha256:9a547d3bc3608b025f93d403fdd1aae741c24fbb8314df4b155675742ce303c5"}, + {file = "idna-3.8-py3-none-any.whl", hash = "sha256:050b4e5baadcd44d760cedbd2b8e639f2ff89bbc7a5730fcc662954303377aac"}, + {file = "idna-3.8.tar.gz", hash = "sha256:d838c2c0ed6fced7693d5e8ab8e734d5f8fda53a039c0164afb0b82e771e3603"}, ] -[package.dependencies] -zipp = ">=0.5" - -[package.extras] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -perf = ["ipython"] -test = ["flufl.flake8", "importlib-resources (>=1.3) ; python_version < \"3.9\"", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-perf (>=0.9.2)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] - [[package]] -name = "importlib-resources" -version = "6.4.4" -description = "Read resources from Python packages" +name = "isort" +version = "6.0.1" +description = "A Python utility / library to sort Python imports." 
optional = false -python-versions = ">=3.8" +python-versions = ">=3.9.0" groups = ["main"] -markers = "python_version < \"3.9\"" files = [ - {file = "importlib_resources-6.4.4-py3-none-any.whl", hash = "sha256:dda242603d1c9cd836c3368b1174ed74cb4049ecd209e7a1a0104620c18c5c11"}, - {file = "importlib_resources-6.4.4.tar.gz", hash = "sha256:20600c8b7361938dc0bb2d5ec0297802e575df486f5a544fa414da65e13721f7"}, + {file = "isort-6.0.1-py3-none-any.whl", hash = "sha256:2dc5d7f65c9678d94c88dfc29161a320eec67328bc97aad576874cb4be1e9615"}, + {file = "isort-6.0.1.tar.gz", hash = "sha256:1cb5df28dfbc742e490c5e41bad6da41b805b0a8be7bc93cd0fb2a8a890ac450"}, ] -[package.dependencies] -zipp = {version = ">=3.1.0", markers = "python_version < \"3.10\""} - [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] -cover = ["pytest-cov"] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -enabler = ["pytest-enabler (>=2.2)"] -test = ["jaraco.test (>=5.4)", "pytest (>=6,!=8.1.*)", "zipp (>=3.17)"] -type = ["pytest-mypy"] +colors = ["colorama"] +plugins = ["setuptools"] [[package]] name = "jinja2" @@ -633,6 +749,18 @@ docs = ["sphinx"] gmpy = ["gmpy2 (>=2.1.0a4) ; platform_python_implementation != \"PyPy\""] tests = ["pytest (>=4.6)"] +[[package]] +name = "mypy-extensions" +version = "1.1.0" +description = "Type system extensions for programs checked with the mypy type checker." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505"}, + {file = "mypy_extensions-1.1.0.tar.gz", hash = "sha256:52e68efc3284861e772bbcd66823fde5ae21fd2fdb51c62a211403730b916558"}, +] + [[package]] name = "networkx" version = "3.1" @@ -935,9 +1063,9 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, {version = ">=1.20.3", markers = "python_version < \"3.10\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, - {version = ">=1.21.0", markers = "python_version == \"3.10\""}, ] python-dateutil = ">=2.8.1" pytz = ">=2020.1" @@ -945,6 +1073,18 @@ pytz = ">=2020.1" [package.extras] test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] +[[package]] +name = "pathspec" +version = "0.12.1" +description = "Utility library for gitignore style pattern matching of file paths." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08"}, + {file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"}, +] + [[package]] name = "patsy" version = "0.5.6" @@ -964,6 +1104,23 @@ six = "*" [package.extras] test = ["pytest", "pytest-cov", "scipy"] +[[package]] +name = "platformdirs" +version = "4.3.8" +description = "A small Python package for determining appropriate platform-specific dirs, e.g. a `user data dir`." 
+optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "platformdirs-4.3.8-py3-none-any.whl", hash = "sha256:ff7059bb7eb1179e2685604f4aaf157cfd9535242bd23742eadc3c13542139b4"}, + {file = "platformdirs-4.3.8.tar.gz", hash = "sha256:3d512d96e16bcb959a814c9f348431070822a6496326a4be0911c40b5a74c2bc"}, +] + +[package.extras] +docs = ["furo (>=2024.8.6)", "proselint (>=0.14)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"] +test = ["appdirs (==1.4.4)", "covdefaults (>=2.3)", "pytest (>=8.3.4)", "pytest-cov (>=6)", "pytest-mock (>=3.14)"] +type = ["mypy (>=1.14.1)"] + [[package]] name = "pmdarima" version = "2.0.4" @@ -1147,6 +1304,110 @@ files = [ {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] +[[package]] +name = "regex" +version = "2024.11.6" +description = "Alternative regular expression module, to replace re." +optional = false +python-versions = ">=3.8" +groups = ["main"] +files = [ + {file = "regex-2024.11.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff590880083d60acc0433f9c3f713c51f7ac6ebb9adf889c79a261ecf541aa91"}, + {file = "regex-2024.11.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:658f90550f38270639e83ce492f27d2c8d2cd63805c65a13a14d36ca126753f0"}, + {file = "regex-2024.11.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:164d8b7b3b4bcb2068b97428060b2a53be050085ef94eca7f240e7947f1b080e"}, + {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d3660c82f209655a06b587d55e723f0b813d3a7db2e32e5e7dc64ac2a9e86fde"}, + {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d22326fcdef5e08c154280b71163ced384b428343ae16a5ab2b3354aed12436e"}, + {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f1ac758ef6aebfc8943560194e9fd0fa18bcb34d89fd8bd2af18183afd8da3a2"}, + {file = "regex-2024.11.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:997d6a487ff00807ba810e0f8332c18b4eb8d29463cfb7c820dc4b6e7562d0cf"}, + {file = "regex-2024.11.6-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02a02d2bb04fec86ad61f3ea7f49c015a0681bf76abb9857f945d26159d2968c"}, + {file = "regex-2024.11.6-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f02f93b92358ee3f78660e43b4b0091229260c5d5c408d17d60bf26b6c900e86"}, + {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:06eb1be98df10e81ebaded73fcd51989dcf534e3c753466e4b60c4697a003b67"}, + {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:040df6fe1a5504eb0f04f048e6d09cd7c7110fef851d7c567a6b6e09942feb7d"}, + {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:fdabbfc59f2c6edba2a6622c647b716e34e8e3867e0ab975412c5c2f79b82da2"}, + {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8447d2d39b5abe381419319f942de20b7ecd60ce86f16a23b0698f22e1b70008"}, + {file = "regex-2024.11.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:da8f5fc57d1933de22a9e23eec290a0d8a5927a5370d24bda9a6abe50683fe62"}, + {file = "regex-2024.11.6-cp310-cp310-win32.whl", hash = "sha256:b489578720afb782f6ccf2840920f3a32e31ba28a4b162e13900c3e6bd3f930e"}, + {file = "regex-2024.11.6-cp310-cp310-win_amd64.whl", hash = 
"sha256:5071b2093e793357c9d8b2929dfc13ac5f0a6c650559503bb81189d0a3814519"}, + {file = "regex-2024.11.6-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5478c6962ad548b54a591778e93cd7c456a7a29f8eca9c49e4f9a806dcc5d638"}, + {file = "regex-2024.11.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2c89a8cc122b25ce6945f0423dc1352cb9593c68abd19223eebbd4e56612c5b7"}, + {file = "regex-2024.11.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:94d87b689cdd831934fa3ce16cc15cd65748e6d689f5d2b8f4f4df2065c9fa20"}, + {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1062b39a0a2b75a9c694f7a08e7183a80c63c0d62b301418ffd9c35f55aaa114"}, + {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:167ed4852351d8a750da48712c3930b031f6efdaa0f22fa1933716bfcd6bf4a3"}, + {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2d548dafee61f06ebdb584080621f3e0c23fff312f0de1afc776e2a2ba99a74f"}, + {file = "regex-2024.11.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f2a19f302cd1ce5dd01a9099aaa19cae6173306d1302a43b627f62e21cf18ac0"}, + {file = "regex-2024.11.6-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bec9931dfb61ddd8ef2ebc05646293812cb6b16b60cf7c9511a832b6f1854b55"}, + {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9714398225f299aa85267fd222f7142fcb5c769e73d7733344efc46f2ef5cf89"}, + {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:202eb32e89f60fc147a41e55cb086db2a3f8cb82f9a9a88440dcfc5d37faae8d"}, + {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:4181b814e56078e9b00427ca358ec44333765f5ca1b45597ec7446d3a1ef6e34"}, + {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:068376da5a7e4da51968ce4c122a7cd31afaaec4fccc7856c92f63876e57b51d"}, + {file = "regex-2024.11.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ac10f2c4184420d881a3475fb2c6f4d95d53a8d50209a2500723d831036f7c45"}, + {file = "regex-2024.11.6-cp311-cp311-win32.whl", hash = "sha256:c36f9b6f5f8649bb251a5f3f66564438977b7ef8386a52460ae77e6070d309d9"}, + {file = "regex-2024.11.6-cp311-cp311-win_amd64.whl", hash = "sha256:02e28184be537f0e75c1f9b2f8847dc51e08e6e171c6bde130b2687e0c33cf60"}, + {file = "regex-2024.11.6-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:52fb28f528778f184f870b7cf8f225f5eef0a8f6e3778529bdd40c7b3920796a"}, + {file = "regex-2024.11.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fdd6028445d2460f33136c55eeb1f601ab06d74cb3347132e1c24250187500d9"}, + {file = "regex-2024.11.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:805e6b60c54bf766b251e94526ebad60b7de0c70f70a4e6210ee2891acb70bf2"}, + {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b85c2530be953a890eaffde05485238f07029600e8f098cdf1848d414a8b45e4"}, + {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bb26437975da7dc36b7efad18aa9dd4ea569d2357ae6b783bf1118dabd9ea577"}, + {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:abfa5080c374a76a251ba60683242bc17eeb2c9818d0d30117b4486be10c59d3"}, + {file = "regex-2024.11.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:70b7fa6606c2881c1db9479b0eaa11ed5dfa11c8d60a474ff0e095099f39d98e"}, + {file = "regex-2024.11.6-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0c32f75920cf99fe6b6c539c399a4a128452eaf1af27f39bce8909c9a3fd8cbe"}, + {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:982e6d21414e78e1f51cf595d7f321dcd14de1f2881c5dc6a6e23bbbbd68435e"}, + {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:a7c2155f790e2fb448faed6dd241386719802296ec588a8b9051c1f5c481bc29"}, + {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:149f5008d286636e48cd0b1dd65018548944e495b0265b45e1bffecce1ef7f39"}, + {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:e5364a4502efca094731680e80009632ad6624084aff9a23ce8c8c6820de3e51"}, + {file = "regex-2024.11.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:0a86e7eeca091c09e021db8eb72d54751e527fa47b8d5787caf96d9831bd02ad"}, + {file = "regex-2024.11.6-cp312-cp312-win32.whl", hash = "sha256:32f9a4c643baad4efa81d549c2aadefaeba12249b2adc5af541759237eee1c54"}, + {file = "regex-2024.11.6-cp312-cp312-win_amd64.whl", hash = "sha256:a93c194e2df18f7d264092dc8539b8ffb86b45b899ab976aa15d48214138e81b"}, + {file = "regex-2024.11.6-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a6ba92c0bcdf96cbf43a12c717eae4bc98325ca3730f6b130ffa2e3c3c723d84"}, + {file = "regex-2024.11.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:525eab0b789891ac3be914d36893bdf972d483fe66551f79d3e27146191a37d4"}, + {file = "regex-2024.11.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:086a27a0b4ca227941700e0b31425e7a28ef1ae8e5e05a33826e17e47fbfdba0"}, + {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bde01f35767c4a7899b7eb6e823b125a64de314a8ee9791367c9a34d56af18d0"}, + {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b583904576650166b3d920d2bcce13971f6f9e9a396c673187f49811b2769dc7"}, + {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1c4de13f06a0d54fa0d5ab1b7138bfa0d883220965a29616e3ea61b35d5f5fc7"}, + {file = "regex-2024.11.6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3cde6e9f2580eb1665965ce9bf17ff4952f34f5b126beb509fee8f4e994f143c"}, + {file = "regex-2024.11.6-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d7f453dca13f40a02b79636a339c5b62b670141e63efd511d3f8f73fba162b3"}, + {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:59dfe1ed21aea057a65c6b586afd2a945de04fc7db3de0a6e3ed5397ad491b07"}, + {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:b97c1e0bd37c5cd7902e65f410779d39eeda155800b65fc4d04cc432efa9bc6e"}, + {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f9d1e379028e0fc2ae3654bac3cbbef81bf3fd571272a42d56c24007979bafb6"}, + {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:13291b39131e2d002a7940fb176e120bec5145f3aeb7621be6534e46251912c4"}, + {file = "regex-2024.11.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4f51f88c126370dcec4908576c5a627220da6c09d0bff31cfa89f2523843316d"}, + {file = "regex-2024.11.6-cp313-cp313-win32.whl", hash = "sha256:63b13cfd72e9601125027202cad74995ab26921d8cd935c25f09c630436348ff"}, + {file = 
"regex-2024.11.6-cp313-cp313-win_amd64.whl", hash = "sha256:2b3361af3198667e99927da8b84c1b010752fa4b1115ee30beaa332cabc3ef1a"}, + {file = "regex-2024.11.6-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:3a51ccc315653ba012774efca4f23d1d2a8a8f278a6072e29c7147eee7da446b"}, + {file = "regex-2024.11.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ad182d02e40de7459b73155deb8996bbd8e96852267879396fb274e8700190e3"}, + {file = "regex-2024.11.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ba9b72e5643641b7d41fa1f6d5abda2c9a263ae835b917348fc3c928182ad467"}, + {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:40291b1b89ca6ad8d3f2b82782cc33807f1406cf68c8d440861da6304d8ffbbd"}, + {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cdf58d0e516ee426a48f7b2c03a332a4114420716d55769ff7108c37a09951bf"}, + {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a36fdf2af13c2b14738f6e973aba563623cb77d753bbbd8d414d18bfaa3105dd"}, + {file = "regex-2024.11.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d1cee317bfc014c2419a76bcc87f071405e3966da434e03e13beb45f8aced1a6"}, + {file = "regex-2024.11.6-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50153825ee016b91549962f970d6a4442fa106832e14c918acd1c8e479916c4f"}, + {file = "regex-2024.11.6-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ea1bfda2f7162605f6e8178223576856b3d791109f15ea99a9f95c16a7636fb5"}, + {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:df951c5f4a1b1910f1a99ff42c473ff60f8225baa1cdd3539fe2819d9543e9df"}, + {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:072623554418a9911446278f16ecb398fb3b540147a7828c06e2011fa531e773"}, + {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:f654882311409afb1d780b940234208a252322c24a93b442ca714d119e68086c"}, + {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:89d75e7293d2b3e674db7d4d9b1bee7f8f3d1609428e293771d1a962617150cc"}, + {file = "regex-2024.11.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:f65557897fc977a44ab205ea871b690adaef6b9da6afda4790a2484b04293a5f"}, + {file = "regex-2024.11.6-cp38-cp38-win32.whl", hash = "sha256:6f44ec28b1f858c98d3036ad5d7d0bfc568bdd7a74f9c24e25f41ef1ebfd81a4"}, + {file = "regex-2024.11.6-cp38-cp38-win_amd64.whl", hash = "sha256:bb8f74f2f10dbf13a0be8de623ba4f9491faf58c24064f32b65679b021ed0001"}, + {file = "regex-2024.11.6-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5704e174f8ccab2026bd2f1ab6c510345ae8eac818b613d7d73e785f1310f839"}, + {file = "regex-2024.11.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:220902c3c5cc6af55d4fe19ead504de80eb91f786dc102fbd74894b1551f095e"}, + {file = "regex-2024.11.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5e7e351589da0850c125f1600a4c4ba3c722efefe16b297de54300f08d734fbf"}, + {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5056b185ca113c88e18223183aa1a50e66507769c9640a6ff75859619d73957b"}, + {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e34b51b650b23ed3354b5a07aab37034d9f923db2a40519139af34f485f77d0"}, + {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:5670bce7b200273eee1840ef307bfa07cda90b38ae56e9a6ebcc9f50da9c469b"}, + {file = "regex-2024.11.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:08986dce1339bc932923e7d1232ce9881499a0e02925f7402fb7c982515419ef"}, + {file = "regex-2024.11.6-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:93c0b12d3d3bc25af4ebbf38f9ee780a487e8bf6954c115b9f015822d3bb8e48"}, + {file = "regex-2024.11.6-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:764e71f22ab3b305e7f4c21f1a97e1526a25ebdd22513e251cf376760213da13"}, + {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:f056bf21105c2515c32372bbc057f43eb02aae2fda61052e2f7622c801f0b4e2"}, + {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:69ab78f848845569401469da20df3e081e6b5a11cb086de3eed1d48f5ed57c95"}, + {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:86fddba590aad9208e2fa8b43b4c098bb0ec74f15718bb6a704e3c63e2cef3e9"}, + {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:684d7a212682996d21ca12ef3c17353c021fe9de6049e19ac8481ec35574a70f"}, + {file = "regex-2024.11.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:a03e02f48cd1abbd9f3b7e3586d97c8f7a9721c436f51a5245b3b9483044480b"}, + {file = "regex-2024.11.6-cp39-cp39-win32.whl", hash = "sha256:41758407fc32d5c3c5de163888068cfee69cb4c2be844e7ac517a52770f9af57"}, + {file = "regex-2024.11.6-cp39-cp39-win_amd64.whl", hash = "sha256:b2837718570f95dd41675328e111345f9b7095d821bac435aac173ac80b19983"}, + {file = "regex-2024.11.6.tar.gz", hash = "sha256:7ab159b063c52a0333c884e4679f8d7a85112ee3078fe3d9004b2dd875585519"}, +] + [[package]] name = "requests" version = "2.32.3" @@ -1169,6 +1430,44 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "safetensors" +version = "0.5.3" +description = "" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "safetensors-0.5.3-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:bd20eb133db8ed15b40110b7c00c6df51655a2998132193de2f75f72d99c7073"}, + {file = "safetensors-0.5.3-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:21d01c14ff6c415c485616b8b0bf961c46b3b343ca59110d38d744e577f9cce7"}, + {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:11bce6164887cd491ca75c2326a113ba934be596e22b28b1742ce27b1d076467"}, + {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4a243be3590bc3301c821da7a18d87224ef35cbd3e5f5727e4e0728b8172411e"}, + {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8bd84b12b1670a6f8e50f01e28156422a2bc07fb16fc4e98bded13039d688a0d"}, + {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:391ac8cab7c829452175f871fcaf414aa1e292b5448bd02620f675a7f3e7abb9"}, + {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cead1fa41fc54b1e61089fa57452e8834f798cb1dc7a09ba3524f1eb08e0317a"}, + {file = "safetensors-0.5.3-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1077f3e94182d72618357b04b5ced540ceb71c8a813d3319f1aba448e68a770d"}, + {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = 
"sha256:799021e78287bac619c7b3f3606730a22da4cda27759ddf55d37c8db7511c74b"}, + {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:df26da01aaac504334644e1b7642fa000bfec820e7cef83aeac4e355e03195ff"}, + {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:32c3ef2d7af8b9f52ff685ed0bc43913cdcde135089ae322ee576de93eae5135"}, + {file = "safetensors-0.5.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:37f1521be045e56fc2b54c606d4455573e717b2d887c579ee1dbba5f868ece04"}, + {file = "safetensors-0.5.3-cp38-abi3-win32.whl", hash = "sha256:cfc0ec0846dcf6763b0ed3d1846ff36008c6e7290683b61616c4b040f6a54ace"}, + {file = "safetensors-0.5.3-cp38-abi3-win_amd64.whl", hash = "sha256:836cbbc320b47e80acd40e44c8682db0e8ad7123209f69b093def21ec7cafd11"}, + {file = "safetensors-0.5.3.tar.gz", hash = "sha256:b6b0d6ecacec39a4fdd99cc19f4576f5219ce858e6fd8dbe7609df0b8dc56965"}, +] + +[package.extras] +all = ["safetensors[jax]", "safetensors[numpy]", "safetensors[paddlepaddle]", "safetensors[pinned-tf]", "safetensors[quality]", "safetensors[testing]", "safetensors[torch]"] +dev = ["safetensors[all]"] +jax = ["flax (>=0.6.3)", "jax (>=0.3.25)", "jaxlib (>=0.3.25)", "safetensors[numpy]"] +mlx = ["mlx (>=0.0.9)"] +numpy = ["numpy (>=1.21.6)"] +paddlepaddle = ["paddlepaddle (>=2.4.1)", "safetensors[numpy]"] +pinned-tf = ["safetensors[numpy]", "tensorflow (==2.18.0)"] +quality = ["black (==22.3)", "click (==8.0.4)", "flake8 (>=3.8.3)", "isort (>=5.5.4)"] +tensorflow = ["safetensors[numpy]", "tensorflow (>=2.11.0)"] +testing = ["h5py (>=3.7.0)", "huggingface-hub (>=0.12.1)", "hypothesis (>=6.70.2)", "pytest (>=7.2.0)", "pytest-benchmark (>=4.0.0)", "safetensors[numpy]", "setuptools-rust (>=1.5.2)"] +torch = ["safetensors[numpy]", "torch (>=1.10)"] + [[package]] name = "scikit-base" version = "0.6.2" @@ -1439,6 +1738,35 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3_binary"] +[[package]] +name = "sqlalchemy-utils" +version = "0.41.2" +description = "Various utility functions for SQLAlchemy." 
+optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "SQLAlchemy-Utils-0.41.2.tar.gz", hash = "sha256:bc599c8c3b3319e53ce6c5c3c471120bd325d0071fb6f38a10e924e3d07b9990"}, + {file = "SQLAlchemy_Utils-0.41.2-py3-none-any.whl", hash = "sha256:85cf3842da2bf060760f955f8467b87983fb2e30f1764fd0e24a48307dc8ec6e"}, +] + +[package.dependencies] +SQLAlchemy = ">=1.3" + +[package.extras] +arrow = ["arrow (>=0.3.4)"] +babel = ["Babel (>=1.3)"] +color = ["colour (>=0.0.4)"] +encrypted = ["cryptography (>=0.6)"] +intervals = ["intervals (>=0.7.1)"] +password = ["passlib (>=1.6,<2.0)"] +pendulum = ["pendulum (>=2.0.5)"] +phone = ["phonenumbers (>=5.9.2)"] +test = ["Jinja2 (>=2.3)", "Pygments (>=1.2)", "backports.zoneinfo ; python_version < \"3.9\"", "docutils (>=0.10)", "flake8 (>=2.4.0)", "flexmock (>=0.9.7)", "isort (>=4.2.2)", "pg8000 (>=1.12.4)", "psycopg (>=3.1.8)", "psycopg2 (>=2.5.1)", "psycopg2cffi (>=2.8.1)", "pymysql", "pyodbc", "pytest (==7.4.4)", "python-dateutil (>=2.6)", "pytz (>=2014.2)"] +test-all = ["Babel (>=1.3)", "Jinja2 (>=2.3)", "Pygments (>=1.2)", "arrow (>=0.3.4)", "backports.zoneinfo ; python_version < \"3.9\"", "colour (>=0.0.4)", "cryptography (>=0.6)", "docutils (>=0.10)", "flake8 (>=2.4.0)", "flexmock (>=0.9.7)", "furl (>=0.4.1)", "intervals (>=0.7.1)", "isort (>=4.2.2)", "passlib (>=1.6,<2.0)", "pendulum (>=2.0.5)", "pg8000 (>=1.12.4)", "phonenumbers (>=5.9.2)", "psycopg (>=3.1.8)", "psycopg2 (>=2.5.1)", "psycopg2cffi (>=2.8.1)", "pymysql", "pyodbc", "pytest (==7.4.4)", "python-dateutil", "python-dateutil (>=2.6)", "pytz (>=2014.2)"] +timezone = ["python-dateutil"] +url = ["furl (>=0.4.1)"] + [[package]] name = "statsmodels" version = "0.14.1" @@ -1480,8 +1808,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">=1.18,<2"}, {version = ">=1.22.3,<2", markers = "python_version == \"3.10\" and platform_system == \"Windows\" and platform_python_implementation != \"PyPy\""}, + {version = ">=1.18,<2", markers = "python_version != \"3.10\" or platform_system != \"Windows\" or platform_python_implementation == \"PyPy\""}, ] packaging = ">=21.3" pandas = ">=1.0,<2.1.0 || >2.1.0" @@ -1525,23 +1853,181 @@ files = [ [[package]] name = "thrift" -version = "0.13.0" +version = "0.22.0" description = "Python bindings for the Apache Thrift RPC system" optional = false python-versions = "*" groups = ["main"] files = [ - {file = "thrift-0.13.0.tar.gz", hash = "sha256:9af1c86bf73433afc6010ed376a6c6aca2b54099cc0d61895f640870a9ae7d89"}, + {file = "thrift-0.22.0.tar.gz", hash = "sha256:42e8276afbd5f54fe1d364858b6877bc5e5a4a5ed69f6a005b94ca4918fe1466"}, ] -[package.dependencies] -six = ">=1.7.2" - [package.extras] all = ["tornado (>=4.0)", "twisted"] tornado = ["tornado (>=4.0)"] twisted = ["twisted"] +[[package]] +name = "tokenizers" +version = "0.19.1" +description = "" +optional = false +python-versions = ">=3.7" +groups = ["main"] +files = [ + {file = "tokenizers-0.19.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:952078130b3d101e05ecfc7fc3640282d74ed26bcf691400f872563fca15ac97"}, + {file = "tokenizers-0.19.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:82c8b8063de6c0468f08e82c4e198763e7b97aabfe573fd4cf7b33930ca4df77"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f03727225feaf340ceeb7e00604825addef622d551cbd46b7b775ac834c1e1c4"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = 
"sha256:453e4422efdfc9c6b6bf2eae00d5e323f263fff62b29a8c9cd526c5003f3f642"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:02e81bf089ebf0e7f4df34fa0207519f07e66d8491d963618252f2e0729e0b46"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b07c538ba956843833fee1190cf769c60dc62e1cf934ed50d77d5502194d63b1"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e28cab1582e0eec38b1f38c1c1fb2e56bce5dc180acb1724574fc5f47da2a4fe"}, + {file = "tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b01afb7193d47439f091cd8f070a1ced347ad0f9144952a30a41836902fe09e"}, + {file = "tokenizers-0.19.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:7fb297edec6c6841ab2e4e8f357209519188e4a59b557ea4fafcf4691d1b4c98"}, + {file = "tokenizers-0.19.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2e8a3dd055e515df7054378dc9d6fa8c8c34e1f32777fb9a01fea81496b3f9d3"}, + {file = "tokenizers-0.19.1-cp310-none-win32.whl", hash = "sha256:7ff898780a155ea053f5d934925f3902be2ed1f4d916461e1a93019cc7250837"}, + {file = "tokenizers-0.19.1-cp310-none-win_amd64.whl", hash = "sha256:bea6f9947e9419c2fda21ae6c32871e3d398cba549b93f4a65a2d369662d9403"}, + {file = "tokenizers-0.19.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:5c88d1481f1882c2e53e6bb06491e474e420d9ac7bdff172610c4f9ad3898059"}, + {file = "tokenizers-0.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ddf672ed719b4ed82b51499100f5417d7d9f6fb05a65e232249268f35de5ed14"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:dadc509cc8a9fe460bd274c0e16ac4184d0958117cf026e0ea8b32b438171594"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dfedf31824ca4915b511b03441784ff640378191918264268e6923da48104acc"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac11016d0a04aa6487b1513a3a36e7bee7eec0e5d30057c9c0408067345c48d2"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:76951121890fea8330d3a0df9a954b3f2a37e3ec20e5b0530e9a0044ca2e11fe"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b342d2ce8fc8d00f376af068e3274e2e8649562e3bc6ae4a67784ded6b99428d"}, + {file = "tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d16ff18907f4909dca9b076b9c2d899114dd6abceeb074eca0c93e2353f943aa"}, + {file = "tokenizers-0.19.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:706a37cc5332f85f26efbe2bdc9ef8a9b372b77e4645331a405073e4b3a8c1c6"}, + {file = "tokenizers-0.19.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:16baac68651701364b0289979ecec728546133e8e8fe38f66fe48ad07996b88b"}, + {file = "tokenizers-0.19.1-cp311-none-win32.whl", hash = "sha256:9ed240c56b4403e22b9584ee37d87b8bfa14865134e3e1c3fb4b2c42fafd3256"}, + {file = "tokenizers-0.19.1-cp311-none-win_amd64.whl", hash = "sha256:ad57d59341710b94a7d9dbea13f5c1e7d76fd8d9bcd944a7a6ab0b0da6e0cc66"}, + {file = "tokenizers-0.19.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:621d670e1b1c281a1c9698ed89451395d318802ff88d1fc1accff0867a06f153"}, + {file = "tokenizers-0.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d924204a3dbe50b75630bd16f821ebda6a5f729928df30f582fb5aade90c818a"}, 
+ {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4f3fefdc0446b1a1e6d81cd4c07088ac015665d2e812f6dbba4a06267d1a2c95"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9620b78e0b2d52ef07b0d428323fb34e8ea1219c5eac98c2596311f20f1f9266"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04ce49e82d100594715ac1b2ce87d1a36e61891a91de774755f743babcd0dd52"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5c2ff13d157afe413bf7e25789879dd463e5a4abfb529a2d8f8473d8042e28f"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3174c76efd9d08f836bfccaca7cfec3f4d1c0a4cf3acbc7236ad577cc423c840"}, + {file = "tokenizers-0.19.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7c9d5b6c0e7a1e979bec10ff960fae925e947aab95619a6fdb4c1d8ff3708ce3"}, + {file = "tokenizers-0.19.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:a179856d1caee06577220ebcfa332af046d576fb73454b8f4d4b0ba8324423ea"}, + {file = "tokenizers-0.19.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:952b80dac1a6492170f8c2429bd11fcaa14377e097d12a1dbe0ef2fb2241e16c"}, + {file = "tokenizers-0.19.1-cp312-none-win32.whl", hash = "sha256:01d62812454c188306755c94755465505836fd616f75067abcae529c35edeb57"}, + {file = "tokenizers-0.19.1-cp312-none-win_amd64.whl", hash = "sha256:b70bfbe3a82d3e3fb2a5e9b22a39f8d1740c96c68b6ace0086b39074f08ab89a"}, + {file = "tokenizers-0.19.1-cp37-cp37m-macosx_10_12_x86_64.whl", hash = "sha256:bb9dfe7dae85bc6119d705a76dc068c062b8b575abe3595e3c6276480e67e3f1"}, + {file = "tokenizers-0.19.1-cp37-cp37m-macosx_11_0_arm64.whl", hash = "sha256:1f0360cbea28ea99944ac089c00de7b2e3e1c58f479fb8613b6d8d511ce98267"}, + {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:71e3ec71f0e78780851fef28c2a9babe20270404c921b756d7c532d280349214"}, + {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b82931fa619dbad979c0ee8e54dd5278acc418209cc897e42fac041f5366d626"}, + {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e8ff5b90eabdcdaa19af697885f70fe0b714ce16709cf43d4952f1f85299e73a"}, + {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e742d76ad84acbdb1a8e4694f915fe59ff6edc381c97d6dfdd054954e3478ad4"}, + {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d8c5d59d7b59885eab559d5bc082b2985555a54cda04dda4c65528d90ad252ad"}, + {file = "tokenizers-0.19.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b2da5c32ed869bebd990c9420df49813709e953674c0722ff471a116d97b22d"}, + {file = "tokenizers-0.19.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:638e43936cc8b2cbb9f9d8dde0fe5e7e30766a3318d2342999ae27f68fdc9bd6"}, + {file = "tokenizers-0.19.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:78e769eb3b2c79687d9cb0f89ef77223e8e279b75c0a968e637ca7043a84463f"}, + {file = "tokenizers-0.19.1-cp37-none-win32.whl", hash = "sha256:72791f9bb1ca78e3ae525d4782e85272c63faaef9940d92142aa3eb79f3407a3"}, + {file = "tokenizers-0.19.1-cp37-none-win_amd64.whl", hash = "sha256:f3bbb7a0c5fcb692950b041ae11067ac54826204318922da754f908d95619fbc"}, + {file = 
"tokenizers-0.19.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:07f9295349bbbcedae8cefdbcfa7f686aa420be8aca5d4f7d1ae6016c128c0c5"}, + {file = "tokenizers-0.19.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:10a707cc6c4b6b183ec5dbfc5c34f3064e18cf62b4a938cb41699e33a99e03c1"}, + {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6309271f57b397aa0aff0cbbe632ca9d70430839ca3178bf0f06f825924eca22"}, + {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ad23d37d68cf00d54af184586d79b84075ada495e7c5c0f601f051b162112dc"}, + {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:427c4f0f3df9109314d4f75b8d1f65d9477033e67ffaec4bca53293d3aca286d"}, + {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e83a31c9cf181a0a3ef0abad2b5f6b43399faf5da7e696196ddd110d332519ee"}, + {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c27b99889bd58b7e301468c0838c5ed75e60c66df0d4db80c08f43462f82e0d3"}, + {file = "tokenizers-0.19.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bac0b0eb952412b0b196ca7a40e7dce4ed6f6926489313414010f2e6b9ec2adf"}, + {file = "tokenizers-0.19.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:8a6298bde623725ca31c9035a04bf2ef63208d266acd2bed8c2cb7d2b7d53ce6"}, + {file = "tokenizers-0.19.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:08a44864e42fa6d7d76d7be4bec62c9982f6f6248b4aa42f7302aa01e0abfd26"}, + {file = "tokenizers-0.19.1-cp38-none-win32.whl", hash = "sha256:1de5bc8652252d9357a666e609cb1453d4f8e160eb1fb2830ee369dd658e8975"}, + {file = "tokenizers-0.19.1-cp38-none-win_amd64.whl", hash = "sha256:0bcce02bf1ad9882345b34d5bd25ed4949a480cf0e656bbd468f4d8986f7a3f1"}, + {file = "tokenizers-0.19.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:0b9394bd204842a2a1fd37fe29935353742be4a3460b6ccbaefa93f58a8df43d"}, + {file = "tokenizers-0.19.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4692ab92f91b87769d950ca14dbb61f8a9ef36a62f94bad6c82cc84a51f76f6a"}, + {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6258c2ef6f06259f70a682491c78561d492e885adeaf9f64f5389f78aa49a051"}, + {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c85cf76561fbd01e0d9ea2d1cbe711a65400092bc52b5242b16cfd22e51f0c58"}, + {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:670b802d4d82bbbb832ddb0d41df7015b3e549714c0e77f9bed3e74d42400fbe"}, + {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:85aa3ab4b03d5e99fdd31660872249df5e855334b6c333e0bc13032ff4469c4a"}, + {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cbf001afbbed111a79ca47d75941e9e5361297a87d186cbfc11ed45e30b5daba"}, + {file = "tokenizers-0.19.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b4c89aa46c269e4e70c4d4f9d6bc644fcc39bb409cb2a81227923404dd6f5227"}, + {file = "tokenizers-0.19.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:39c1ec76ea1027438fafe16ecb0fb84795e62e9d643444c1090179e63808c69d"}, + {file = "tokenizers-0.19.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:c2a0d47a89b48d7daa241e004e71fb5a50533718897a4cd6235cb846d511a478"}, + {file = "tokenizers-0.19.1-cp39-none-win32.whl", hash = 
"sha256:61b7fe8886f2e104d4caf9218b157b106207e0f2a4905c9c7ac98890688aabeb"}, + {file = "tokenizers-0.19.1-cp39-none-win_amd64.whl", hash = "sha256:f97660f6c43efd3e0bfd3f2e3e5615bf215680bad6ee3d469df6454b8c6e8256"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:3b11853f17b54c2fe47742c56d8a33bf49ce31caf531e87ac0d7d13d327c9334"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d26194ef6c13302f446d39972aaa36a1dda6450bc8949f5eb4c27f51191375bd"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:e8d1ed93beda54bbd6131a2cb363a576eac746d5c26ba5b7556bc6f964425594"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca407133536f19bdec44b3da117ef0d12e43f6d4b56ac4c765f37eca501c7bda"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce05fde79d2bc2e46ac08aacbc142bead21614d937aac950be88dc79f9db9022"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:35583cd46d16f07c054efd18b5d46af4a2f070a2dd0a47914e66f3ff5efb2b1e"}, + {file = "tokenizers-0.19.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:43350270bfc16b06ad3f6f07eab21f089adb835544417afda0f83256a8bf8b75"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b4399b59d1af5645bcee2072a463318114c39b8547437a7c2d6a186a1b5a0e2d"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6852c5b2a853b8b0ddc5993cd4f33bfffdca4fcc5d52f89dd4b8eada99379285"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bcd266ae85c3d39df2f7e7d0e07f6c41a55e9a3123bb11f854412952deacd828"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ecb2651956eea2aa0a2d099434134b1b68f1c31f9a5084d6d53f08ed43d45ff2"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:b279ab506ec4445166ac476fb4d3cc383accde1ea152998509a94d82547c8e2a"}, + {file = "tokenizers-0.19.1-pp37-pypy37_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:89183e55fb86e61d848ff83753f64cded119f5d6e1f553d14ffee3700d0a4a49"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:b2edbc75744235eea94d595a8b70fe279dd42f3296f76d5a86dde1d46e35f574"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-macosx_11_0_arm64.whl", hash = "sha256:0e64bfde9a723274e9a71630c3e9494ed7b4c0f76a1faacf7fe294cd26f7ae7c"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:0b5ca92bfa717759c052e345770792d02d1f43b06f9e790ca0a1db62838816f3"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f8a20266e695ec9d7a946a019c1d5ca4eddb6613d4f466888eee04f16eedb85"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63c38f45d8f2a2ec0f3a20073cccb335b9f99f73b3c69483cd52ebc75369d8a1"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:dd26e3afe8a7b61422df3176e06664503d3f5973b94f45d5c45987e1cb711876"}, + {file = "tokenizers-0.19.1-pp38-pypy38_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:eddd5783a4a6309ce23432353cdb36220e25cbb779bfa9122320666508b44b88"}, + {file = 
"tokenizers-0.19.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:56ae39d4036b753994476a1b935584071093b55c7a72e3b8288e68c313ca26e7"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:f9939ca7e58c2758c01b40324a59c034ce0cebad18e0d4563a9b1beab3018243"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:6c330c0eb815d212893c67a032e9dc1b38a803eccb32f3e8172c19cc69fbb439"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ec11802450a2487cdf0e634b750a04cbdc1c4d066b97d94ce7dd2cb51ebb325b"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2b718f316b596f36e1dae097a7d5b91fc5b85e90bf08b01ff139bd8953b25af"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:ed69af290c2b65169f0ba9034d1dc39a5db9459b32f1dd8b5f3f32a3fcf06eab"}, + {file = "tokenizers-0.19.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f8a9c828277133af13f3859d1b6bf1c3cb6e9e1637df0e45312e6b7c2e622b1f"}, + {file = "tokenizers-0.19.1.tar.gz", hash = "sha256:ee59e6680ed0fdbe6b724cf38bd70400a0c1dd623b07ac729087270caeac88e3"}, +] + +[package.dependencies] +huggingface-hub = ">=0.16.4,<1.0" + +[package.extras] +dev = ["tokenizers[testing]"] +docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"] +testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests", "ruff"] + +[[package]] +name = "tomli" +version = "2.2.1" +description = "A lil' TOML parser" +optional = false +python-versions = ">=3.8" +groups = ["main"] +markers = "python_version < \"3.11\"" +files = [ + {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, + {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8"}, + {file = "tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff"}, + {file = "tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8"}, + {file = 
"tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e"}, + {file = "tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98"}, + {file = "tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4"}, + {file = "tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7"}, + {file = "tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744"}, + {file = "tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec"}, + {file = "tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69"}, + {file = "tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc"}, + {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"}, +] + [[package]] name = "torch" version = "2.2.0" @@ -1622,6 +2108,75 @@ notebook = ["ipywidgets (>=6)"] slack = ["slack-sdk"] telegram = ["requests"] +[[package]] +name = "transformers" +version = "4.40.1" +description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" +optional = false +python-versions = ">=3.8.0" +groups = ["main"] +files = [ + {file = "transformers-4.40.1-py3-none-any.whl", hash = "sha256:9d5ee0c8142a60501faf9e49a0b42f8e9cb8611823bce4f195a9325a6816337e"}, + {file = "transformers-4.40.1.tar.gz", hash = 
"sha256:55e1697e6f18b58273e7117bb469cdffc11be28995462d8d5e422fef38d2de36"}, +] + +[package.dependencies] +filelock = "*" +huggingface-hub = ">=0.19.3,<1.0" +numpy = ">=1.17" +packaging = ">=20.0" +pyyaml = ">=5.1" +regex = "!=2019.12.17" +requests = "*" +safetensors = ">=0.4.1" +tokenizers = ">=0.19,<0.20" +tqdm = ">=4.27" + +[package.extras] +accelerate = ["accelerate (>=0.21.0)"] +agents = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "datasets (!=2.5.0)", "diffusers", "opencv-python", "sentencepiece (>=0.1.91,!=0.1.92)", "torch"] +all = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] +audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +codecarbon = ["codecarbon (==1.2.0)"] +deepspeed = ["accelerate (>=0.21.0)", "deepspeed (>=0.9.3)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.21.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", 
"pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.19,<0.20)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "beautifulsoup4", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pydantic", "pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "ray[tune] (>=2.7.0)", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +docs = ["Pillow (>=10.0.1,<=15.0)", "accelerate (>=0.21.0)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune] (>=2.7.0)", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx", "timm", "tokenizers (>=0.19,<0.20)", "torch", "torchaudio", "torchvision"] +docs-specific = ["hf-doc-builder"] +flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax (>=0.0.8,<=0.1.4)"] +flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +ftfy = ["ftfy"] +integrations = ["optuna", "ray[tune] (>=2.7.0)", "sigopt"] +ja = ["fugashi (>=1.0)", "ipadic (>=1.0.0,<2.0)", "rhoknp (>=1.1.0,<1.3.1)", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)"] +modelcreation = ["cookiecutter (==1.7.3)"] +natten = ["natten (>=0.14.6,<0.15.0)"] +onnx = ["onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "tf2onnx"] +onnxruntime = ["onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)"] +optuna = ["optuna"] +quality = ["GitPython (<3.1.19)", "datasets (!=2.5.0)", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "ruff (==0.1.5)", "urllib3 (<2.0.0)"] +ray = ["ray[tune] (>=2.7.0)"] +retrieval = ["datasets (!=2.5.0)", "faiss-cpu"] +sagemaker = ["sagemaker (>=2.31.0)"] +sentencepiece = ["protobuf", "sentencepiece (>=0.1.91,!=0.1.92)"] +serving = ["fastapi", "pydantic", "starlette", "uvicorn"] +sigopt = ["sigopt"] +sklearn = ["scikit-learn"] +speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pydantic", 
"pytest (>=7.2.0,<8.0.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (==0.1.5)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.16)", "tensorflow-text (<2.16)", "tf2onnx"] +tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] +timm = ["timm"] +tokenizers = ["tokenizers (>=0.19,<0.20)"] +torch = ["accelerate (>=0.21.0)", "torch"] +torch-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] +torch-vision = ["Pillow (>=10.0.1,<=15.0)", "torchvision"] +torchhub = ["filelock", "huggingface-hub (>=0.19.3,<1.0)", "importlib-metadata", "numpy (>=1.17)", "packaging (>=20.0)", "protobuf", "regex (!=2019.12.17)", "requests", "sentencepiece (>=0.1.91,!=0.1.92)", "tokenizers (>=0.19,<0.20)", "torch", "tqdm (>=4.27)"] +video = ["av (==9.2.0)", "decord (==0.6.0)"] +vision = ["Pillow (>=10.0.1,<=15.0)"] + [[package]] name = "triton" version = "2.2.0" @@ -1659,6 +2214,37 @@ files = [ {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] +[[package]] +name = "tzdata" +version = "2025.2" +description = "Provider of IANA time zone data" +optional = false +python-versions = ">=2" +groups = ["main"] +markers = "platform_system == \"Windows\"" +files = [ + {file = "tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8"}, + {file = "tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9"}, +] + +[[package]] +name = "tzlocal" +version = "5.3.1" +description = "tzinfo object for the local timezone" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "tzlocal-5.3.1-py3-none-any.whl", hash = "sha256:eb1a66c3ef5847adf7a834f1be0800581b683b5608e74f86ecbcef8ab91bb85d"}, + {file = "tzlocal-5.3.1.tar.gz", hash = "sha256:cceffc7edecefea1f595541dbd6e990cb1ea3d19bf01b2809f362a03dd7921fd"}, +] + +[package.dependencies] +tzdata = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +devenv = ["check-manifest", "pytest (>=4.3)", "pytest-cov", "pytest-mock (>=3.3)", "zest.releaser"] + [[package]] name = "urllib3" version = "2.2.2" @@ -1677,28 +2263,7 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] -[[package]] -name = "zipp" -version = "3.20.1" -description = "Backport of pathlib-compatible object wrapper for zip files" -optional = false -python-versions = ">=3.8" -groups = ["main"] -markers = "python_version < \"3.9\"" -files = [ - {file = "zipp-3.20.1-py3-none-any.whl", hash = "sha256:9960cd8967c8f85a56f920d5d507274e74f9ff813a0ab8889a5b5be2daf44064"}, - {file = "zipp-3.20.1.tar.gz", hash = "sha256:c22b14cc4763c5a5b04134207736c107db42e9d3ef2d9779d465f5f1bcba572b"}, -] - -[package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] -cover = ["pytest-cov"] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -enabler = ["pytest-enabler (>=2.2)"] -test = ["big-O", "importlib-resources ; python_version < \"3.9\"", "jaraco.functools", "jaraco.itertools", 
"jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] -type = ["pytest-mypy"] - [metadata] lock-version = "2.1" -python-versions = ">=3.8, <3.13" -content-hash = "f8ff25befae83d79c99b9eb13009b72fcbb717da26800d179eaf807f19a747f7" +python-versions = ">=3.9, <3.13" +content-hash = "ef32360fe61470ec51e88ab26dad5596dba4e3c088e80457c37f515cb4e03b2f" diff --git a/iotdb-core/ainode/pyproject.toml b/iotdb-core/ainode/pyproject.toml index eaa418b623f95..e8cacb42a9f77 100644 --- a/iotdb-core/ainode/pyproject.toml +++ b/iotdb-core/ainode/pyproject.toml @@ -21,7 +21,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "apache-iotdb-ainode" -version = "2.0.0.dev" +version = "2.0.4.dev" description = "Apache IoTDB AINode" readme = "README.md" authors = ["Apache Software Foundation "] @@ -46,7 +46,7 @@ packages = [ ] [tool.poetry.dependencies] -python = ">=3.8, <3.13" +python = ">=3.9, <3.13" numpy = "^1.21.4" pandas = "^1.3.5" torch = ">=2.2.0" @@ -63,7 +63,12 @@ apache-iotdb = "2.0.4.dev0" einops = "^0.8.1" safetensors = "^0.5.1" huggingface_hub = "^0.30.1" +black = "25.1.0" +isort = "6.0.1" transformers = "==4.40.1" [tool.poetry.scripts] ainode = "ainode.core.script:main" + +[tool.isort] +profile = "black" \ No newline at end of file