52 changes: 52 additions & 0 deletions .github/workflows/greeting-ainode.yml
@@ -0,0 +1,52 @@
name: AINode Code Style Check

on:
  push:
    branches:
      - master
      - "rc/*"
    paths:
      - 'iotdb-core/ainode/**'
  pull_request:
    branches:
      - master
      - "rc/*"
    paths:
      - 'iotdb-core/ainode/**'
  # allow manually running the action:
  workflow_dispatch:

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: true

env:
  MAVEN_OPTS: -Dhttp.keepAlive=false -Dmaven.wagon.http.pool=false -Dmaven.wagon.http.retryHandler.class=standard -Dmaven.wagon.http.retryHandler.count=3
  MAVEN_ARGS: --batch-mode --no-transfer-progress

jobs:
  check-style:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python 3.10
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"

      - name: Install dependencies
        run: |
          pip3 install black==25.1.0 isort==6.0.1
      - name: Check code formatting (Black)
        run: |
          cd iotdb-core/ainode
          black --check .
        continue-on-error: false

      - name: Check import order (Isort)
        run: |
          cd iotdb-core/ainode
          isort --check-only --profile black .
        continue-on-error: false
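
The job above simply runs Black and isort in check mode inside `iotdb-core/ainode`. A minimal local helper that mirrors those two steps is sketched below; the script name and its invocation from the repository root are assumptions (it is not part of the PR), and it presumes `black==25.1.0` and `isort==6.0.1` are installed as in the workflow.

```python
# check_ainode_style.py -- hypothetical local helper mirroring the CI steps above.
# Assumes black==25.1.0 and isort==6.0.1 are on PATH and the script runs from the repo root.
import subprocess
import sys

AINODE_DIR = "iotdb-core/ainode"


def main() -> int:
    checks = [
        ["black", "--check", "."],
        ["isort", "--check-only", "--profile", "black", "."],
    ]
    status = 0
    for cmd in checks:
        # Run each tool in check-only mode inside the ainode module, as the workflow does.
        status |= subprocess.run(cmd, cwd=AINODE_DIR).returncode
    return status


if __name__ == "__main__":
    sys.exit(main())
```

Running the helper before pushing should exit non-zero whenever the CI check would fail.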
2 changes: 1 addition & 1 deletion iotdb-core/ainode/ainode/TimerXL/__init__.py
@@ -14,4 +14,4 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#
43 changes: 28 additions & 15 deletions iotdb-core/ainode/ainode/TimerXL/layers/Attn_Bias.py
@@ -17,6 +17,7 @@
#
import abc
import math

import torch
from einops import rearrange
from torch import nn
@@ -41,22 +42,23 @@ def __init__(self, dim: int, num_heads: int):

    def forward(self, query_id, kv_id):
        ind = torch.eq(query_id.unsqueeze(-1), kv_id.unsqueeze(-2))
        weight = rearrange(
            self.emb.weight, "two num_heads -> two num_heads 1 1")
        weight = rearrange(self.emb.weight, "two num_heads -> two num_heads 1 1")
        bias = ~ind * weight[:1] + ind * weight[1:]
        return bias


def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
def _relative_position_bucket(
    relative_position, bidirectional=True, num_buckets=32, max_distance=128
):
    relative_buckets = 0
    if bidirectional:
        num_buckets //= 2
        relative_buckets += (relative_position >
                             0).to(torch.long) * num_buckets
        relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
        relative_position = torch.abs(relative_position)
    else:
        relative_position = - \
            torch.min(relative_position, torch.zeros_like(relative_position))
        relative_position = -torch.min(
            relative_position, torch.zeros_like(relative_position)
        )

    max_exact = num_buckets // 2
    is_small = relative_position < max_exact
@@ -66,12 +68,13 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets
        * (num_buckets - max_exact)
    ).to(torch.long)
    relative_position_if_large = torch.min(
        relative_position_if_large, torch.full_like(
            relative_position_if_large, num_buckets - 1)
        relative_position_if_large,
        torch.full_like(relative_position_if_large, num_buckets - 1),
    )

    relative_buckets += torch.where(is_small,
                                    relative_position, relative_position_if_large)
    relative_buckets += torch.where(
        is_small, relative_position, relative_position_if_large
    )
    return relative_buckets


@@ -83,11 +86,21 @@ def __init__(self, dim: int, num_heads: int):
        self.relative_attention_bias = nn.Embedding(self.num_buckets, 1)

    def forward(self, n_vars, n_tokens):
        context_position = torch.arange(n_tokens, dtype=torch.long,)[:, None]
        memory_position = torch.arange(n_tokens, dtype=torch.long, )[None, :]
        context_position = torch.arange(
            n_tokens,
            dtype=torch.long,
        )[:, None]
        memory_position = torch.arange(
            n_tokens,
            dtype=torch.long,
        )[None, :]
        relative_position = memory_position - context_position
        bucket = _relative_position_bucket(relative_position=relative_position, bidirectional=False,
                                           num_buckets=self.num_buckets, max_distance=self.max_distance).to(self.relative_attention_bias.weight.device)
        bucket = _relative_position_bucket(
            relative_position=relative_position,
            bidirectional=False,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        ).to(self.relative_attention_bias.weight.device)
        bias = self.relative_attention_bias(bucket).squeeze(-1)
        bias = bias.reshape(1, 1, bias.shape[0], bias.shape[1])
        mask1 = torch.ones((n_vars, n_vars), dtype=torch.bool).to(bias.device)
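
The reformatted `_relative_position_bucket` above is the usual T5-style bucketing: offsets below `max_exact` keep their own bucket, while larger offsets are binned logarithmically up to `max_distance`. A small usage sketch follows; the import path and the tensor sizes are illustrative assumptions, not taken from the PR.

```python
# Illustrative only: exercises the bucketing helper shown in the diff above.
# The import path is an assumption based on the repository layout (iotdb-core/ainode/ainode/...).
import torch

from ainode.TimerXL.layers.Attn_Bias import _relative_position_bucket

n_tokens = 8
context_position = torch.arange(n_tokens, dtype=torch.long)[:, None]
memory_position = torch.arange(n_tokens, dtype=torch.long)[None, :]
relative_position = memory_position - context_position  # signed offsets, shape (8, 8)

# Causal setting (bidirectional=False), as in the forward() shown above;
# num_buckets and max_distance are the signature defaults.
buckets = _relative_position_bucket(
    relative_position, bidirectional=False, num_buckets=32, max_distance=128
)
print(buckets.shape, int(buckets.min()), int(buckets.max()))  # values stay within [0, 31]
```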
14 changes: 9 additions & 5 deletions iotdb-core/ainode/ainode/TimerXL/layers/Attn_Projection.py
@@ -16,8 +16,9 @@
# under the License.
#
import abc
import torch
from functools import cached_property

import torch
from einops import einsum, rearrange, repeat
from torch import nn
@@ -33,7 +34,9 @@ def forward(self, x, seq_id): ...


class RotaryProjection(Projection):
    def __init__(self, *, proj_width: int, num_heads: int, max_len: int = 512, base: int = 10000):
    def __init__(
        self, *, proj_width: int, num_heads: int, max_len: int = 512, base: int = 10000
    ):
        super().__init__(proj_width, num_heads)
        assert (
            self.proj_width % 2 == 0
@@ -57,8 +60,7 @@ def _init_freq(self, max_len: int):
        position = torch.arange(
            max_len, device=self.theta.device, dtype=self.theta.dtype
        )
        m_theta = einsum(position, self.theta,
                         "length, width -> length width")
        m_theta = einsum(position, self.theta, "length, width -> length width")
        m_theta = repeat(m_theta, "length width -> length (width 2)")
        self.register_buffer("cos", torch.cos(m_theta), persistent=False)
        self.register_buffer("sin", torch.sin(m_theta), persistent=False)
@@ -76,7 +78,9 @@ def forward(self, x, seq_id):


class QueryKeyProjection(nn.Module):
    def __init__(self, dim: int, num_heads: int, proj_layer, kwargs=None, partial_factor=None):
    def __init__(
        self, dim: int, num_heads: int, proj_layer, kwargs=None, partial_factor=None
    ):
        super().__init__()
        if partial_factor is not None:
            assert (
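
The `_init_freq` hunk above precomputes the rotary-embedding cos/sin tables from an outer product of positions and per-dimension frequencies. The standalone sketch below reproduces just that table construction; since the diff does not show how `self.theta` is initialized, the inverse-frequency formula used here is the conventional RoPE definition and should be treated as an assumption, as should the `proj_width` value.

```python
# Standalone illustration of the sin/cos tables built in RotaryProjection._init_freq above.
# theta uses the standard RoPE inverse-frequency formula -- an assumption, since the diff
# does not show how self.theta is defined.
import torch
from einops import einsum, repeat

proj_width, max_len, base = 64, 512, 10000  # max_len and base match the constructor defaults
theta = 1.0 / (base ** (torch.arange(0, proj_width, 2).float() / proj_width))
position = torch.arange(max_len, dtype=theta.dtype)

# Outer product of positions and frequencies, then duplicate each frequency for its (x, y) pair.
m_theta = einsum(position, theta, "length, width -> length width")
m_theta = repeat(m_theta, "length width -> length (width 2)")
cos, sin = torch.cos(m_theta), torch.sin(m_theta)
print(cos.shape)  # torch.Size([512, 64])
```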