diff --git a/.pylintrc b/.pylintrc
index 4e65a5b3..d3f932dc 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -12,7 +12,7 @@ ignore=CVS
# Add files or directories matching the regex patterns to the blacklist.
# The regex matches against base names, not paths.
-ignore-patterns=
+ignore-patterns=test_transformer_blocks
# Python code to execute, usually for sys.path manipulation
# such as pygtk.require().
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 070e44c9..3683c370 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+## [Unreleased]
+
+### Added
+- `rectools.semantic` package with TIGER generative recommender components, including model, Lightning integration, loss, tokenizer modules, semantic metrics, and data handling utilities
+- Semantic building blocks for neural recommender workflows, including transformer blocks, MLP modules, residual k-means, and RQ-VAE tokenizer implementations
+- TIGER tutorial notebook, SASRec vs TIGER benchmark script, and semantic test coverage across data handling, modules, tokenizer, metrics, and TIGER components
+
## [Unreleased]
diff --git a/benchmark/compare_sasrec_tiger.py b/benchmark/compare_sasrec_tiger.py
new file mode 100644
index 00000000..0045503e
--- /dev/null
+++ b/benchmark/compare_sasrec_tiger.py
@@ -0,0 +1,810 @@
+"""Compare RecTools SASRecModel and semantic TIGER on ML-20M.
+
+Both models are trained on MovieLens 20M data and the benchmark writes a
+Markdown report to ``benchmark/comparison_sasrec_vs_tiger.md``.
+
+Usage:
+```bash
+python3 -m benchmark.compare_sasrec_tiger
+```
+
+The script downloads ML-20M automatically if it is not present locally.
+Movie metadata is taken from ``movies.csv`` and embedded with
+``Qwen/Qwen3-Embedding-0.6B`` for the TIGER tokenizer.
+"""
+
+# pylint: disable=too-many-locals,too-many-statements,import-outside-toplevel
+# pylint: disable=import-error,unsubscriptable-object,protected-access,cyclic-import
+
+import argparse
+import gc
+import shutil
+import time
+import zipfile
+from datetime import datetime
+from pathlib import Path
+from typing import Optional
+from urllib.request import urlretrieve
+
+import numpy as np
+import pandas as pd
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from sentence_transformers import SentenceTransformer
+from tqdm.auto import trange
+
+from rectools import Columns
+from rectools.dataset import Dataset
+from rectools.metrics import MRR, NDCG, CatalogCoverage, HitRate
+from rectools.models import SASRecModel
+from rectools.models.nn.item_net import IdEmbeddingsItemNet
+from rectools.models.nn.transformers.utils import leave_one_out_mask
+from rectools.semantic.data_handling import k_core, loo_split
+from rectools.semantic.metrics import gini_k
+from rectools.semantic.tiger import TIGERModel
+from rectools.semantic.tokenizer import SIDTokenizer
+
+# Directory that contains this script; the defaults below are resolved relative to it.
+BENCHMARK_DIR = Path(__file__).resolve().parent
+# Default value of ``--workdir`` (downloaded data and cached artifacts).
+DEFAULT_WORKDIR = BENCHMARK_DIR / "data" / "ml-20m"
+# Default value of ``--report-path`` (the generated Markdown report).
+DEFAULT_REPORT_PATH = BENCHMARK_DIR / "comparison_sasrec_vs_tiger.md"
+# Archive fetched by ``download_ml20m`` when the CSV files are missing.
+ML20M_URL = "https://files.grouplens.org/datasets/movielens/ml-20m.zip"
+
+
+def parse_args() -> argparse.Namespace:
+ parser = argparse.ArgumentParser(
+ description="Benchmark RecTools SASRec against semantic TIGER on MovieLens 20M."
+ )
+ parser.add_argument(
+ "--workdir",
+ type=Path,
+ default=DEFAULT_WORKDIR,
+ help="Directory for downloaded data and cached artifacts.",
+ )
+ parser.add_argument(
+ "--report-path",
+ type=Path,
+ default=DEFAULT_REPORT_PATH,
+ help="Path to the Markdown report produced by the benchmark.",
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ default=None,
+ help="Torch device override. Defaults to cuda -> mps -> cpu.",
+ )
+ parser.add_argument(
+ "--top-k-main",
+ type=int,
+ default=10,
+ help="K for NDCG, MRR, coverage, and gini.",
+ )
+ parser.add_argument(
+ "--max-length", type=int, default=200, help="Maximum context length."
+ )
+ parser.add_argument("--top-k-hit", type=int, default=10, help="K for hit rate.")
+ parser.add_argument("--seed", type=int, default=42, help="Random seed.")
+ parser.add_argument(
+ "--limit-users",
+ type=int,
+ default=None,
+ help="Optional cap on users after preprocessing.",
+ )
+ parser.add_argument(
+ "--min-rating",
+ type=float,
+ default=-1.0,
+ help="Optional minimum rating filter. Use -1 to keep all ratings.",
+ )
+ parser.add_argument(
+ "--min-item-interactions",
+ type=int,
+ default=5,
+ help="Optional minimum number of interactions per item.",
+ )
+ parser.add_argument(
+ "--min-user-interactions",
+ type=int,
+ default=2,
+ help="Optional minimum number of interactions per user.",
+ )
+ parser.add_argument(
+ "--embedding-model",
+ type=str,
+ default="Qwen/Qwen3-Embedding-0.6B",
+ help="SentenceTransformer model used to embed movie metadata.",
+ )
+ parser.add_argument(
+ "--tokenizer-max-epochs",
+ type=int,
+ default=1000,
+ help="Tokenizer training epochs.",
+ )
+ parser.add_argument(
+ "--tokenizer-batch-size",
+ type=int,
+ default=256,
+ help="Batch size for metadata embedding and tokenizer training.",
+ )
+ parser.add_argument(
+ "--tokenizer-patience",
+ type=int,
+ default=100,
+ help="Tokenizer early stopping patience.",
+ )
+ parser.add_argument(
+ "--sasrec-epochs", type=int, default=10, help="Maximum SASRec epochs."
+ )
+ parser.add_argument(
+ "--sasrec-patience",
+ type=int,
+ default=10,
+ help="SASRec early stopping patience.",
+ )
+ parser.add_argument(
+ "--sasrec-batch-size", type=int, default=256, help="SASRec batch size."
+ )
+ parser.add_argument(
+ "--sasrec-num-workers", type=int, default=4, help="SASRec dataloader workers."
+ )
+ parser.add_argument(
+ "--tiger-epochs", type=int, default=10, help="Maximum TIGER epochs."
+ )
+ parser.add_argument(
+ "--tiger-patience", type=int, default=10, help="TIGER early stopping patience."
+ )
+ parser.add_argument(
+ "--tiger-batch-size", type=int, default=256, help="TIGER train batch size."
+ )
+ parser.add_argument(
+ "--tiger-eval-batch-size", type=int, default=256, help="TIGER eval batch size."
+ )
+ parser.add_argument(
+ "--tiger-num-workers", type=int, default=4, help="TIGER dataloader workers."
+ )
+ return parser.parse_args()
+
+
+def resolve_device(device: Optional[str]) -> str:
+ if device is not None:
+ return device
+ if torch.cuda.is_available():
+ return "cuda"
+ if (
+ getattr(torch.backends, "mps", None) is not None
+ and torch.backends.mps.is_available()
+ ):
+ return "mps"
+ return "cpu"
+
+
+def describe_device(device: str) -> str:
+ if device.startswith("cuda") and torch.cuda.is_available():
+ gpu_index = torch.device(device).index or 0
+ return f"{device} ({torch.cuda.get_device_name(gpu_index)})"
+ if device == "mps":
+ return "mps (Apple Metal)"
+ return "cpu"
+
+
+def maybe_download(url: str, dst: Path) -> None:
+ if dst.exists():
+ return
+ dst.parent.mkdir(parents=True, exist_ok=True)
+ print(f"Downloading {url} -> {dst}")
+ urlretrieve(url, dst) # nosec B310
+
+
+def download_ml20m(workdir: Path) -> tuple[Path, Path]:
+ ratings_path = workdir / "ratings.csv"
+ movies_path = workdir / "movies.csv"
+ if ratings_path.exists() and movies_path.exists():
+ return ratings_path, movies_path
+
+ archive_path = workdir / "ml-20m.zip"
+ maybe_download(ML20M_URL, archive_path)
+
+ print(f"Extracting {archive_path}")
+ with zipfile.ZipFile(archive_path) as archive:
+ for member in archive.namelist():
+ basename = Path(member).name
+ if basename not in {
+ "ratings.csv",
+ "movies.csv",
+ "README.txt",
+ "links.csv",
+ "tags.csv",
+ }:
+ continue
+ target = workdir / basename
+ if target.exists():
+ continue
+ with archive.open(member) as src, open(target, "wb") as dst_stream:
+ shutil.copyfileobj(src, dst_stream)
+
+ return ratings_path, movies_path
+
+
+def load_interactions(ratings_path: Path, args: argparse.Namespace) -> pd.DataFrame:
+ """Read the ratings CSV and apply the benchmark's interaction filters.
+
+ Steps, in order: optional minimum-rating filter, optional k-core
+ filtering, optional user cap (followed by another k-core pass), sorting
+ by user / timestamp / item, and conversion of ``timestamp`` from epoch
+ seconds to datetimes.
+
+ Parameters
+ ----------
+ ratings_path : Path
+ CSV file with four columns, renamed positionally to
+ ``user_id``, ``item_id``, ``rating``, ``timestamp``.
+ args : argparse.Namespace
+ Reads ``min_rating``, ``min_item_interactions``,
+ ``min_user_interactions`` and ``limit_users``.
+
+ Returns
+ -------
+ pd.DataFrame
+ Filtered, sorted interactions with a fresh 0..n-1 index.
+ """
+ interactions = pd.read_csv(ratings_path)
+ # Positional rename: assigning four names requires exactly four columns.
+ interactions.columns = ["user_id", "item_id", "rating", "timestamp"]
+
+ # Non-positive thresholds (the CLI default is -1) disable the rating filter.
+ if args.min_rating > 0:
+ interactions = interactions[interactions["rating"] >= args.min_rating]
+
+ # K-core runs when at least one of the two thresholds is positive;
+ # a negative threshold is clamped to 0 before being passed on.
+ use_k_core = args.min_item_interactions > 0 or args.min_user_interactions > 0
+ if use_k_core:
+ print(
+ f"Performing K-Core filtering (user>={args.min_user_interactions}, item>={args.min_item_interactions})..."
+ )
+ interactions = k_core(
+ interactions,
+ user_min_interactions=max(args.min_user_interactions, 0),
+ item_min_interactions=max(args.min_item_interactions, 0),
+ )
+
+ if args.limit_users is not None:
+ # Keep the first ``limit_users`` distinct users in order of appearance.
+ kept_users = interactions["user_id"].drop_duplicates().iloc[: args.limit_users]
+ interactions = interactions[
+ interactions["user_id"].isin(kept_users)
+ ].reset_index(drop=True)
+ # NOTE(review): presumably repeated because dropping users can push
+ # items back below the interaction threshold — confirm.
+ if use_k_core:
+ interactions = k_core(
+ interactions,
+ user_min_interactions=max(args.min_user_interactions, 0),
+ item_min_interactions=max(args.min_item_interactions, 0),
+ )
+
+ # ``item_id`` is the last sort key, so equal timestamps get a fixed order.
+ interactions = interactions.sort_values(
+ ["user_id", "timestamp", "item_id"]
+ ).reset_index(drop=True)
+ # Timestamps are interpreted as seconds since the epoch.
+ interactions["timestamp"] = pd.to_datetime(interactions["timestamp"], unit="s")
+
+ return interactions
+
+
+def save_preprocessed_files(
+ workdir: Path,
+ interactions: pd.DataFrame,
+ meta: pd.DataFrame,
+ train_df: pd.DataFrame,
+ val_df: pd.DataFrame,
+ test_df: pd.DataFrame,
+ test_targets: pd.DataFrame,
+) -> dict[str, Path]:
+ preprocessed_dir = workdir / "preprocessed"
+ preprocessed_dir.mkdir(parents=True, exist_ok=True)
+
+ paths = {
+ "interactions": preprocessed_dir / "interactions.csv",
+ "item_metadata": preprocessed_dir / "item_metadata.csv",
+ "train": preprocessed_dir / "train.csv",
+ "val": preprocessed_dir / "val.csv",
+ "test": preprocessed_dir / "test.csv",
+ "test_targets": preprocessed_dir / "test_targets.csv",
+ }
+ interactions.to_csv(paths["interactions"], index=False)
+ meta.to_csv(paths["item_metadata"], index=False)
+ train_df.to_csv(paths["train"], index=False)
+ val_df.to_csv(paths["val"], index=False)
+ test_df.to_csv(paths["test"], index=False)
+ test_targets.to_csv(paths["test_targets"], index=False)
+
+ return paths
+
+
+def build_meta(movies_path: Path, interactions: pd.DataFrame) -> pd.DataFrame:
+ movies = pd.read_csv(movies_path)
+ movies = movies.rename(
+ columns={"movieId": "item_id", "title": "title", "genres": "genres"}
+ )
+ item_ids = set(interactions["item_id"].unique())
+
+ meta = movies[movies["item_id"].isin(item_ids)].copy()
+ meta["title"] = meta["title"].fillna("Unknown").astype(str)
+ meta["genres"] = meta["genres"].fillna("Unknown").astype(str)
+ meta["text"] = meta["title"] + " Genres: " + meta["genres"]
+ meta = meta[["item_id", "text"]].drop_duplicates("item_id")
+
+ missing_items = sorted(item_ids - set(meta["item_id"]))
+ if missing_items:
+ missing_meta = pd.DataFrame(
+ {
+ "item_id": missing_items,
+ "text": ["Unknown title. Genres: Unknown"] * len(missing_items),
+ }
+ )
+ meta = pd.concat([meta, missing_meta], ignore_index=True)
+
+ return meta.sort_values("item_id").reset_index(drop=True)
+
+
+def map_ids(
+ interactions: pd.DataFrame, meta: pd.DataFrame
+) -> tuple[pd.DataFrame, pd.DataFrame]:
+ user_codes = pd.Index(interactions["user_id"].unique())
+ item_codes = pd.Index(interactions["item_id"].unique())
+
+ user_map = pd.Series(np.arange(len(user_codes), dtype=np.int64), index=user_codes)
+ item_map = pd.Series(np.arange(len(item_codes), dtype=np.int64), index=item_codes)
+
+ mapped_interactions = interactions.copy()
+ mapped_interactions["user_id"] = (
+ mapped_interactions["user_id"].map(user_map).astype(np.int64)
+ )
+ mapped_interactions["item_id"] = (
+ mapped_interactions["item_id"].map(item_map).astype(np.int64)
+ )
+
+ mapped_meta = meta[meta["item_id"].isin(item_map.index)].copy()
+ mapped_meta["item_id"] = mapped_meta["item_id"].map(item_map).astype(np.int64)
+ mapped_meta = mapped_meta.sort_values("item_id").reset_index(drop=True)
+
+ return mapped_interactions, mapped_meta
+
+
+def encode_texts(
+ texts: list[str], model_name: str, device: str, batch_size: int
+) -> np.ndarray:
+ embeds = []
+ model = SentenceTransformer(model_name, device=device).eval()
+ with torch.no_grad():
+ for start in trange(0, len(texts), batch_size, desc="Embedding movie metadata"):
+ end = min(start + batch_size, len(texts))
+ embeds.append(model.encode(texts[start:end], show_progress_bar=False))
+ return np.vstack(embeds)
+
+
+def make_sasrec_dataset(val_df: pd.DataFrame) -> Dataset:
+ sasrec_df = val_df.rename(columns={"timestamp": Columns.Datetime}).copy()
+ sasrec_df[Columns.Weight] = 1.0
+ return Dataset.construct(
+ interactions_df=sasrec_df[
+ [Columns.User, Columns.Item, Columns.Weight, Columns.Datetime]
+ ]
+ )
+
+
+def get_last_item_targets(test_df: pd.DataFrame) -> pd.DataFrame:
+ targets = (
+ test_df.sort_values(["user_id", "timestamp"])
+ .groupby("user_id", sort=False)
+ .tail(1)
+ .copy()
+ )
+ targets["weight"] = 1.0
+ return targets[["user_id", "item_id", "weight"]]
+
+
+def reco_to_topk_lists(reco: pd.DataFrame, k: int) -> list[list[int]]:
+ ordered = reco[reco[Columns.Rank] <= k].sort_values([Columns.User, Columns.Rank])
+ grouped = ordered.groupby(Columns.User, sort=False)[Columns.Item].agg(list)
+ return grouped.tolist()
+
+
+def evaluate_recommendations(
+ reco: pd.DataFrame,
+ targets: pd.DataFrame,
+ catalog_size: int,
+ top_k_main: int,
+ top_k_hit: int,
+) -> dict[str, float]:
+ return {
+ f"NDCG@{top_k_main}": NDCG(k=top_k_main).calc(reco, targets),
+ f"HR@{top_k_hit}": HitRate(k=top_k_hit).calc(reco, targets),
+ f"MRR@{top_k_main}": MRR(k=top_k_main).calc(reco, targets),
+ f"Coverage@{top_k_main}": CatalogCoverage(k=top_k_main, normalize=True).calc(
+ reco, catalog=list(range(catalog_size))
+ ),
+ f"Gini@{top_k_main}": gini_k(reco_to_topk_lists(reco, top_k_main), top_k_main),
+ }
+
+
+def train_sasrec(
+ dataset: Dataset,
+ device: str,
+ max_epochs: int,
+ patience: int,
+ batch_size: int,
+ num_workers: int,
+ max_length: int,
+ min_user_interactions: int,
+) -> SASRecModel:
+ accelerator = "gpu" if device.startswith("cuda") else device
+ trainer_kwargs = {
+ "accelerator": accelerator,
+ "devices": 1,
+ "min_epochs": 1,
+ "max_epochs": max_epochs,
+ "callbacks": [
+ EarlyStopping(
+ monitor=SASRecModel.val_loss_name,
+ patience=patience,
+ mode="min",
+ )
+ ],
+ "deterministic": True,
+ }
+ if device.startswith("cuda"):
+ trainer_kwargs["precision"] = "bf16-mixed"
+
+ trainer = Trainer(**trainer_kwargs)
+
+ model = SASRecModel(
+ n_factors=200,
+ n_blocks=1,
+ n_heads=2,
+ dropout_rate=0.1,
+ train_min_user_interactions=min_user_interactions,
+ session_max_len=max_length,
+ item_net_block_types=(IdEmbeddingsItemNet,),
+ get_val_mask_func=leave_one_out_mask,
+ batch_size=batch_size,
+ dataloader_num_workers=num_workers,
+ verbose=1,
+ deterministic=True,
+ recommend_torch_device=device,
+ )
+ model._trainer = trainer
+ model.fit(dataset)
+ return model
+
+
+def cleanup() -> None:
+ gc.collect()
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+
+def benchmark_sasrec(
+ val_df: pd.DataFrame,
+ test_targets: pd.DataFrame,
+ device: str,
+ args: argparse.Namespace,
+) -> tuple[dict[str, float], float, float]:
+ """Train SASRec, time training and inference, and score its recommendations.
+
+ Parameters
+ ----------
+ val_df : pd.DataFrame
+ Interaction history used both to fit the model and as the
+ recommendation context.
+ test_targets : pd.DataFrame
+ Held-out targets; its ``user_id`` column defines the users scored.
+ device : str
+ Torch device string forwarded to ``train_sasrec``.
+ args : argparse.Namespace
+ Reads the ``sasrec_*`` options, ``max_length``,
+ ``min_user_interactions``, ``top_k_main`` and ``top_k_hit``.
+
+ Returns
+ -------
+ tuple[dict[str, float], float, float]
+ ``(metrics, train_seconds, inference_seconds)``.
+ """
+ print("Converting validation history to RecTools dataset for SASRec")
+ dataset = make_sasrec_dataset(val_df)
+
+ print("Fitting RecTools SASRec")
+ # The training timer covers model construction and fitting only;
+ # the dataset conversion above is excluded.
+ fit_start = time.perf_counter()
+ model = train_sasrec(
+ dataset=dataset,
+ device=device,
+ max_epochs=args.sasrec_epochs,
+ patience=args.sasrec_patience,
+ batch_size=args.sasrec_batch_size,
+ num_workers=args.sasrec_num_workers,
+ max_length=args.max_length,
+ min_user_interactions=args.min_user_interactions,
+ )
+ train_seconds = time.perf_counter() - fit_start
+
+ users = test_targets["user_id"].to_numpy()
+ inference_start = time.perf_counter()
+ # Request enough items for the larger of the two metric cut-offs.
+ # Already-seen items are not filtered out of the recommendations.
+ reco = model.recommend(
+ users=users,
+ dataset=dataset,
+ k=max(args.top_k_main, args.top_k_hit),
+ filter_viewed=False,
+ )
+ inference_seconds = time.perf_counter() - inference_start
+
+ # Catalog size is the number of distinct items in the validation history.
+ metrics = evaluate_recommendations(
+ reco=reco,
+ targets=test_targets,
+ catalog_size=val_df["item_id"].nunique(),
+ top_k_main=args.top_k_main,
+ top_k_hit=args.top_k_hit,
+ )
+ return metrics, train_seconds, inference_seconds
+
+
+def benchmark_tiger(
+ train_df: pd.DataFrame,
+ val_df: pd.DataFrame,
+ test_targets: pd.DataFrame,
+ meta: pd.DataFrame,
+ device: str,
+ args: argparse.Namespace,
+) -> tuple[dict[str, float], float, float]:
+ """Fit the tokenizer and TIGER, time training and inference, and score the output.
+
+ Parameters
+ ----------
+ train_df, val_df : pd.DataFrame
+ Splits passed to ``TIGERModel.fit``; ``val_df`` is also the input of
+ ``predict``.
+ test_targets : pd.DataFrame
+ Held-out targets used for the metrics.
+ meta : pd.DataFrame
+ Item metadata with ``item_id`` and ``text`` columns; ``text`` is embedded
+ to train the tokenizer.
+ device : str
+ Torch device string used for embedding, tokenizer and model.
+ args : argparse.Namespace
+ Reads ``embedding_model``, the ``tokenizer_*`` and ``tiger_*`` options,
+ ``workdir``, ``max_length``, ``seed``, ``top_k_main`` and ``top_k_hit``.
+
+ Returns
+ -------
+ tuple[dict[str, float], float, float]
+ ``(metrics, train_seconds, inference_seconds)``. Metadata embedding and
+ tokenizer fitting happen before the training timer starts and are not
+ included in ``train_seconds``.
+ """
+ meta = meta.sort_values("item_id").drop_duplicates("item_id")
+ embeddings = encode_texts(
+ meta["text"].tolist(),
+ model_name=args.embedding_model,
+ device=device,
+ batch_size=args.tokenizer_batch_size,
+ )
+ tokenizer = SIDTokenizer(
+ input_dim=embeddings.shape[1],
+ codebook_sizes=[256, 256, 256],
+ codebook_dim=32,
+ hidden_dims=[768, 256, 128, 64],
+ quantizer="rqvae",
+ device=device,
+ )
+ # Item ids and embeddings come from the same sorted, de-duplicated ``meta``.
+ tokenizer.fit(
+ item_ids=meta["item_id"].tolist(),
+ embeddings=embeddings,
+ max_epochs=args.tokenizer_max_epochs,
+ patience=args.tokenizer_patience,
+ batch_size=args.tokenizer_batch_size,
+ save_dir=str(args.workdir / "tokenizer_checkpoints"),
+ )
+
+ # NOTE(review): ``beam_size`` only accounts for ``top_k_hit``, while
+ # ``predict`` below asks for max(top_k_main, top_k_hit) items — confirm the
+ # beam is wide enough when ``top_k_main`` exceeds 20.
+ tiger = TIGERModel(
+ tokenizer=tokenizer,
+ hidden_units=128,
+ num_blocks=4,
+ num_heads=6,
+ dropout_rate=0.1,
+ max_length=args.max_length,
+ lr=1e-3,
+ lr_schedule="cosine",
+ warmup_steps=0,
+ max_epochs=args.tiger_epochs,
+ patience=args.tiger_patience,
+ batch_size=args.tiger_batch_size,
+ eval_batch_size=args.tiger_eval_batch_size,
+ num_workers=args.tiger_num_workers,
+ beam_size=max(20, args.top_k_hit),
+ top_k=args.top_k_main,
+ d_kv=64,
+ ff_dim=1024,
+ device=device,
+ random_seed=args.seed,
+ )
+
+ print("Fitting TIGER")
+ fit_start = time.perf_counter()
+ tiger.fit(train_df=train_df, val_df=val_df)
+ train_seconds = time.perf_counter() - fit_start
+
+ inference_start = time.perf_counter()
+ # Request enough items for the larger of the two metric cut-offs.
+ reco = tiger.predict(val_df, top_k=max(args.top_k_main, args.top_k_hit))
+ inference_seconds = time.perf_counter() - inference_start
+
+ # Catalog size is the number of distinct items in the validation history.
+ metrics = evaluate_recommendations(
+ reco=reco,
+ targets=test_targets,
+ catalog_size=val_df["item_id"].nunique(),
+ top_k_main=args.top_k_main,
+ top_k_hit=args.top_k_hit,
+ )
+ return metrics, train_seconds, inference_seconds
+
+
+def collect_dataset_info(
+ interactions: pd.DataFrame,
+ train_df: pd.DataFrame,
+ val_df: pd.DataFrame,
+ test_df: pd.DataFrame,
+ meta: pd.DataFrame,
+) -> dict[str, object]:
+ timestamps = interactions["timestamp"]
+ return {
+ "n_interactions": len(interactions),
+ "n_users": interactions["user_id"].nunique(),
+ "n_items": interactions["item_id"].nunique(),
+ "n_train": len(train_df),
+ "n_val": len(val_df),
+ "n_test": len(test_df),
+ "n_meta_items": len(meta),
+ "min_timestamp": timestamps.min(),
+ "max_timestamp": timestamps.max(),
+ "avg_interactions_per_user": len(interactions)
+ / interactions["user_id"].nunique(),
+ "avg_interactions_per_item": len(interactions)
+ / interactions["item_id"].nunique(),
+ }
+
+
+def write_report(
+ report_path: Path,
+ args: argparse.Namespace,
+ device: str,
+ data_info: dict[str, object],
+ preprocessed_paths: dict[str, Path],
+ results: dict[str, dict[str, float]],
+ started_at: datetime,
+ finished_at: datetime,
+) -> None:
+ report_path.parent.mkdir(parents=True, exist_ok=True)
+ report_date = finished_at.strftime("%Y-%m-%d %H:%M:%S")
+ duration_seconds = (finished_at - started_at).total_seconds()
+ dataset_label = "MovieLens 20M"
+
+ metric_names = [
+ f"NDCG@{args.top_k_main}",
+ f"HR@{args.top_k_hit}",
+ f"MRR@{args.top_k_main}",
+ f"Coverage@{args.top_k_main}",
+ f"Gini@{args.top_k_main}",
+ ]
+
+ lines = [
+ "# SASRec vs TIGER Comparison",
+ "",
+ f"**Date:** {report_date} ",
+ f"**Duration:** {duration_seconds:.1f}s ",
+ f"**Device:** {describe_device(device)} ",
+ f"**Dataset:** {dataset_label} ",
+ f"**Report path:** `{report_path}`",
+ "",
+ "## Run Summary",
+ "",
+ "| Field | Value |",
+ "|---|---|",
+ f"| Data directory | `{args.workdir}` |",
+ f"| Ratings source | `{args.workdir / 'ratings.csv'}` |",
+ f"| Movies source | `{args.workdir / 'movies.csv'}` |",
+ f"| Embedding model | `{args.embedding_model}` |",
+ f"| Torch version | `{torch.__version__}` |",
+ f"| PyTorch Lightning version | `{pl.__version__}` |",
+ f"| Seed | {args.seed} |",
+ f"| Max sequence length | {args.max_length} |",
+ f"| top_k_main | {args.top_k_main} |",
+ f"| top_k_hit | {args.top_k_hit} |",
+ "",
+ "## Dataset Statistics",
+ "",
+ "| Statistic | Value |",
+ "|---|---:|",
+ f"| Interactions | {data_info['n_interactions']:,} |",
+ f"| Users | {data_info['n_users']:,} |",
+ f"| Items | {data_info['n_items']:,} |",
+ f"| Metadata rows | {data_info['n_meta_items']:,} |",
+ f"| Train interactions | {data_info['n_train']:,} |",
+ f"| Validation interactions | {data_info['n_val']:,} |",
+ f"| Test interactions | {data_info['n_test']:,} |",
+ f"| Avg interactions per user | {data_info['avg_interactions_per_user']:.2f} |",
+ f"| Avg interactions per item | {data_info['avg_interactions_per_item']:.2f} |",
+ f"| First timestamp | {data_info['min_timestamp']} |",
+ f"| Last timestamp | {data_info['max_timestamp']} |",
+ "",
+ "## Filtering and Data Preparation",
+ "",
+ "| Parameter | Value |",
+ "|---|---|",
+ f"| Minimum rating | {args.min_rating} |",
+ f"| Minimum item interactions | {args.min_item_interactions} |",
+ f"| Minimum user interactions | {args.min_user_interactions} |",
+ f"| User cap | {args.limit_users if args.limit_users is not None else 'None'} |",
+ "| Interaction count filter | Iterative k-core |",
+ "| Split strategy | Leave-one-out per user |",
+ "| TIGER metadata text | `title + genres` from `movies.csv` |",
+ "",
+ "## Preprocessed Files",
+ "",
+ "| File | Path |",
+ "|---|---|",
+ f"| Interactions | `{preprocessed_paths['interactions']}` |",
+ f"| Item metadata | `{preprocessed_paths['item_metadata']}` |",
+ f"| Train split | `{preprocessed_paths['train']}` |",
+ f"| Validation split | `{preprocessed_paths['val']}` |",
+ f"| Test split | `{preprocessed_paths['test']}` |",
+ f"| Test targets | `{preprocessed_paths['test_targets']}` |",
+ "",
+ "## Model Configuration",
+ "",
+ "| Parameter | SASRec | TIGER |",
+ "|---|---|---|",
+ f"| Epochs | {args.sasrec_epochs} | {args.tiger_epochs} |",
+ f"| Patience | {args.sasrec_patience} | {args.tiger_patience} |",
+ f"| Batch size | {args.sasrec_batch_size} | {args.tiger_batch_size} |",
+ f"| Eval batch size | - | {args.tiger_eval_batch_size} |",
+ f"| Num workers | {args.sasrec_num_workers} | {args.tiger_num_workers} |",
+ f"| Max length | {args.max_length} | {args.max_length} |",
+ "| Architecture notes | RecTools SASRec defaults from this script | TIGER + SIDTokenizer (RQ-VAE) |",
+ "",
+ "## Results",
+ "",
+ "| Model | Train time, s | Inference time, s | "
+ + " | ".join(metric_names)
+ + " |",
+ "|---|---:|---:|" + "---:" * len(metric_names) + "|",
+ ]
+
+ for model_name, model_metrics in results.items():
+ row = [
+ model_name,
+ f"{model_metrics['train_seconds']:.3f}",
+ f"{model_metrics['inference_seconds']:.3f}",
+ ]
+ row.extend(f"{model_metrics[metric_name]:.6f}" for metric_name in metric_names)
+ lines.append("| " + " | ".join(row) + " |")
+
+ lines.extend(
+ [
+ "",
+ "## Notes",
+ "",
+ "- The script downloads ML-20M automatically if the data files are missing.",
+ "- TIGER uses stochastic SID collision resolution; this benchmark fixes it with the run seed.",
+ "- Coverage is computed against the total number of unique mapped items in the evaluation catalog.",
+ ]
+ )
+
+ report_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
+ print(f"Report written to {report_path}")
+
+
+def main() -> None:
+ """Run the full benchmark: prepare data, evaluate both models, write the report."""
+ args = parse_args()
+ # The reported duration starts here, so it includes download and preprocessing.
+ started_at = datetime.now()
+ pl.seed_everything(args.seed, workers=True)
+
+ device = resolve_device(args.device)
+ args.workdir.mkdir(parents=True, exist_ok=True)
+
+ ratings_path, movies_path = download_ml20m(args.workdir)
+
+ interactions = load_interactions(ratings_path, args)
+ # Metadata is built from the raw item ids, then both frames are re-coded together.
+ meta = build_meta(movies_path, interactions)
+ interactions, meta = map_ids(interactions, meta)
+
+ train_df, val_df, test_df = loo_split(interactions, timestamp_col="timestamp")
+ test_targets = get_last_item_targets(test_df)
+ preprocessed_paths = save_preprocessed_files(
+ workdir=args.workdir,
+ interactions=interactions,
+ meta=meta,
+ train_df=train_df,
+ val_df=val_df,
+ test_df=test_df,
+ test_targets=test_targets,
+ )
+ data_info = collect_dataset_info(interactions, train_df, val_df, test_df, meta)
+
+ print(
+ "Prepared dataset:",
+ f"{data_info['n_interactions']:,} interactions,",
+ f"{data_info['n_users']:,} users,",
+ f"{data_info['n_items']:,} items",
+ )
+
+ sasrec_metrics, sasrec_train_s, sasrec_infer_s = benchmark_sasrec(
+ val_df=val_df,
+ test_targets=test_targets,
+ device=device,
+ args=args,
+ )
+ # Collect garbage and empty the CUDA cache between the two model runs.
+ cleanup()
+
+ tiger_metrics, tiger_train_s, tiger_infer_s = benchmark_tiger(
+ train_df=train_df,
+ val_df=val_df,
+ test_targets=test_targets,
+ meta=meta,
+ device=device,
+ args=args,
+ )
+
+ # Timings and metrics share one flat dict per model, as ``write_report`` expects.
+ results = {
+ "SASRec": {
+ "train_seconds": sasrec_train_s,
+ "inference_seconds": sasrec_infer_s,
+ **sasrec_metrics,
+ },
+ "TIGER": {
+ "train_seconds": tiger_train_s,
+ "inference_seconds": tiger_infer_s,
+ **tiger_metrics,
+ },
+ }
+
+ finished_at = datetime.now()
+ write_report(
+ report_path=args.report_path,
+ args=args,
+ device=device,
+ data_info=data_info,
+ preprocessed_paths=preprocessed_paths,
+ results=results,
+ started_at=started_at,
+ finished_at=finished_at,
+ )
+
+
+# Run the benchmark only when executed as a script or via ``python -m``.
+if __name__ == "__main__":
+ main()
diff --git a/examples/tutorials/tiger_tutorial.ipynb b/examples/tutorials/tiger_tutorial.ipynb
new file mode 100644
index 00000000..f4bc37d3
--- /dev/null
+++ b/examples/tutorials/tiger_tutorial.ipynb
@@ -0,0 +1,1711 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "intro",
+ "metadata": {},
+ "source": [
+ "# TIGER: Generative Retrieval with Semantic IDs\n",
+ "\n",
+ "This tutorial walks through the TIGER generative retrieval model for sequential recommendation\n",
+ "using the [MovieLens 20M](https://grouplens.org/datasets/movielens/20m/) dataset.\n",
+ "\n",
+ "TIGER represents items as **Semantic IDs** -- tuples of discrete codes produced by RQ-VAE.\n",
+ "A Transformer encoder-decoder is then trained to predict the next item's Semantic ID given a user's interaction history.\n",
+ "\n",
+ "### Pipeline overview\n",
+ "\n",
+ "1. Download and prepare ML-20M interaction data\n",
+ "2. Apply k-core filtering\n",
+ "3. Train a `SIDTokenizer` (on sentence-transformers embeddings of item metadata)\n",
+ "4. Create a `TIGERModel` with the tokenizer\n",
+ "5. Split data with `loo_split` (leave-one-out)\n",
+ "6. Train with `fit(train_df, val_df)`\n",
+ "7. Evaluate with `evaluate(test_df)`\n",
+ "8. Generate recommendations with `predict(interactions)`\n",
+ "9. Save / load the trained model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "setup-header",
+ "metadata": {},
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "imports",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import warnings\n",
+ "\n",
+ "from rectools.semantic.data_handling import k_core, loo_split\n",
+ "from rectools.semantic.tokenizer import SIDTokenizer\n",
+ "from rectools.semantic.tiger import TIGERModel\n",
+ "\n",
+ "warnings.simplefilter(\"ignore\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "data-header",
+ "metadata": {},
+ "source": [
+ "## 1. Download and prepare ML-20M\n",
+ "\n",
+ "We download the MovieLens 20M dataset and prepare it as a DataFrame with\n",
+ "`user_id`, `item_id`, and `timestamp` columns."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "download-data",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[ml-20m.zip]\n",
+ " End-of-central-directory signature not found. Either this file is not\n",
+ " a zipfile, or it constitutes one disk of a multi-part archive. In the\n",
+ " latter case the central directory and zipfile comment will be found on\n",
+ " the last disk(s) of this archive.\n",
+ "unzip: cannot find zipfile directory in one of ml-20m.zip or\n",
+ " ml-20m.zip.zip, and cannot find ml-20m.zip.ZIP, period.\n",
+ "CPU times: user 8.84 ms, sys: 41.8 ms, total: 50.7 ms\n",
+ "Wall time: 794 ms\n"
+ ]
+ }
+ ],
+ "source": [
+ "%%time\n",
+ "!wget -q https://files.grouplens.org/datasets/movielens/ml-20m.zip -O ml-20m.zip\n",
+ "!unzip -o -q ml-20m.zip\n",
+ "!rm ml-20m.zip"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "load-data",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Interactions: 9,995,410\n",
+ "Users: 138,287, Items: 20,720\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " item_id | \n",
+ " timestamp | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 1 | \n",
+ " 1079 | \n",
+ " 1094785665 | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 1 | \n",
+ " 2959 | \n",
+ " 1094785698 | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 1 | \n",
+ " 3996 | \n",
+ " 1094785727 | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 1 | \n",
+ " 151 | \n",
+ " 1094785734 | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 1 | \n",
+ " 1374 | \n",
+ " 1094785746 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " user_id item_id timestamp\n",
+ "0 1 1079 1094785665\n",
+ "1 1 2959 1094785698\n",
+ "2 1 3996 1094785727\n",
+ "3 1 151 1094785734\n",
+ "4 1 1374 1094785746"
+ ]
+ },
+ "execution_count": 3,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ratings = pd.read_csv(\"ml-20m/ratings.csv\")\n",
+ "ratings.columns = [\"user_id\", \"item_id\", \"rating\", \"timestamp\"]\n",
+ "\n",
+ "# Keep only positive interactions (rating >= 4)\n",
+ "interactions = ratings[ratings[\"rating\"] >= 4][[\"user_id\", \"item_id\", \"timestamp\"]].copy()\n",
+ "interactions = interactions.sort_values([\"user_id\", \"timestamp\"]).reset_index(drop=True)\n",
+ "\n",
+ "print(f\"Interactions: {len(interactions):,}\")\n",
+ "print(f\"Users: {interactions['user_id'].nunique():,}, Items: {interactions['item_id'].nunique():,}\")\n",
+ "interactions.head()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "load-movies",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " item_id | \n",
+ " title | \n",
+ " genres | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 1 | \n",
+ " Toy Story (1995) | \n",
+ " Adventure|Animation|Children|Comedy|Fantasy | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 2 | \n",
+ " Jumanji (1995) | \n",
+ " Adventure|Children|Fantasy | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 3 | \n",
+ " Grumpier Old Men (1995) | \n",
+ " Comedy|Romance | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 4 | \n",
+ " Waiting to Exhale (1995) | \n",
+ " Comedy|Drama|Romance | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 5 | \n",
+ " Father of the Bride Part II (1995) | \n",
+ " Comedy | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " item_id title \\\n",
+ "0 1 Toy Story (1995) \n",
+ "1 2 Jumanji (1995) \n",
+ "2 3 Grumpier Old Men (1995) \n",
+ "3 4 Waiting to Exhale (1995) \n",
+ "4 5 Father of the Bride Part II (1995) \n",
+ "\n",
+ " genres \n",
+ "0 Adventure|Animation|Children|Comedy|Fantasy \n",
+ "1 Adventure|Children|Fantasy \n",
+ "2 Comedy|Romance \n",
+ "3 Comedy|Drama|Romance \n",
+ "4 Comedy "
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "movies = pd.read_csv(\"ml-20m/movies.csv\")\n",
+ "movies.columns = [\"item_id\", \"title\", \"genres\"]\n",
+ "movies.head()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "kcore-header",
+ "metadata": {},
+ "source": [
+ "## 2. K-core filtering\n",
+ "\n",
+ "`k_core()` iteratively removes users and items with fewer than the specified minimum interactions.\n",
+ "This ensures that every user and item has enough signal for sequential modeling."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "kcore",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "After k-core: 9,977,451 interactions\n",
+ "Users: 136,674, Items: 13,680\n"
+ ]
+ }
+ ],
+ "source": [
+ "interactions = k_core(\n",
+ " interactions,\n",
+ " user_min_interactions=5,\n",
+ " item_min_interactions=5,\n",
+ ")\n",
+ "\n",
+ "print(f\"After k-core: {len(interactions):,} interactions\")\n",
+ "print(f\"Users: {interactions['user_id'].nunique():,}, Items: {interactions['item_id'].nunique():,}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "tokenizer-header",
+ "metadata": {},
+ "source": [
+ "## 3. Train a tokenizer\n",
+ "\n",
+ "The `SIDTokenizer` maps each item ID to a **Semantic ID** -- a tuple of discrete codes.\n",
+ "It uses RQ-VAE (or RK-means) to quantize item embeddings into hierarchical codebooks.\n",
+ "\n",
+ "### Embedding items with sentence-transformers\n",
+ "\n",
+ "Here we embed item metadata (movie title and genres) using\n",
+ "a sentence-transformers model and pass the resulting embeddings to `SIDTokenizer.fit()`:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "4f6e3d21",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "from sentence_transformers import SentenceTransformer\n",
+ "from tqdm.auto import trange\n",
+ "\n",
+ "def encode_texts(texts: list, model_name: str, device: str = \"cpu\", batch_size: int = 256) -> np.ndarray:\n",
+ " \"\"\"Encode texts into embeddings using a SentenceTransformer model.\n",
+ "\n",
+ " Parameters\n",
+ " ----------\n",
+ " texts : list of str\n",
+ " Texts to encode.\n",
+ " model_name : str\n",
+ " Name or path of a sentence-transformers model.\n",
+ " device : str\n",
+ " Device to run the model on.\n",
+ " batch_size : int\n",
+ " Batch size for encoding.\n",
+ "\n",
+ " Returns\n",
+ " -------\n",
+ " np.ndarray\n",
+ " Embedding matrix of shape ``(len(texts), embed_dim)``.\n",
+ " \"\"\"\n",
+ " embeds = []\n",
+ " model = SentenceTransformer(model_name).to(device).eval()\n",
+ " with torch.no_grad():\n",
+ " for start in trange(0, len(texts), batch_size, desc=\"Embedding the metadata\"):\n",
+ " end = min(start + batch_size, len(texts))\n",
+ " embeds.append(model.encode(texts[start:end]))\n",
+ "\n",
+ " return np.vstack(embeds)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "2ea6e31b-0a9c-4fa3-8464-eb35df9dd87b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0d0b8b2fd59648838f70c6f0a4609108",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Embedding the metadata: 0%| | 0/54 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "f7c4f8b15a9a43639457791bad946562",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Initializing codebooks with constrained k-means: 0%| | 0/3 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c8d2226f8fa54b23a02d76da684509e3",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/700 [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Early stopping at epoch 22 (patience exceeded 10)\n",
+ "Best model's collision rate: 0.42%\n",
+ "Tokenizer vocabulary: 13651 unique SIDs\n"
+ ]
+ }
+ ],
+ "source": [
+ "from rectools.semantic.tokenizer import SIDTokenizer\n",
+ "\n",
+ "# Build text descriptions from movie metadata\n",
+ "meta = movies[movies[\"item_id\"].isin(interactions[\"item_id\"])].copy()\n",
+ "meta[\"text\"] = meta[\"title\"] + \" \" + meta[\"genres\"].str.replace(\"|\", \" \")\n",
+ "\n",
+ "# Encode with a sentence-transformers model\n",
+ "embeddings = encode_texts(\n",
+ " meta[\"text\"].tolist(),\n",
+ " model_name=\"Qwen/Qwen3-Embedding-0.6B\",\n",
+ " device=\"cuda\",\n",
+ ")\n",
+ "\n",
+ "# Train the tokenizer\n",
+ "tokenizer = SIDTokenizer(\n",
+ " input_dim=embeddings.shape[1],\n",
+ " codebook_sizes=[256, 256, 256],\n",
+ " codebook_dim=32,\n",
+ " quantizer=\"rqvae\",\n",
+ " device=\"cuda\",\n",
+ ")\n",
+ "tokenizer.fit(\n",
+ " item_ids=meta[\"item_id\"].tolist(),\n",
+ " embeddings=embeddings,\n",
+ " max_epochs=100,\n",
+ " patience=10,\n",
+ ")\n",
+ "\n",
+ "print(f\"Tokenizer vocabulary: {len(tokenizer)} unique SIDs\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1d8f8529-bf4d-47e2-9b20-602ad318fd6b",
+ "metadata": {},
+ "source": [
+ "Let's see which items get duplicate semantic IDs:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "99e068e3-4781-43f7-966e-4c679f71068f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " SID | \n",
+ " item_id | \n",
+ " text | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " (235, 229, 202) | \n",
+ " 421 | \n",
+ " Black Beauty (1994) Adventure Children Drama | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " (235, 229, 202) | \n",
+ " 986 | \n",
+ " Fly Away Home (1996) Adventure Children | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " (10, 64, 101) | \n",
+ " 502 | \n",
+ " Next Karate Kid, The (1994) Action Children Ro... | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " (10, 64, 101) | \n",
+ " 2422 | \n",
+ " Karate Kid, Part III, The (1989) Action Advent... | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " (15, 132, 12) | \n",
+ " 913 | \n",
+ " Maltese Falcon, The (1941) Film-Noir Mystery | \n",
+ "
\n",
+ " \n",
+ " | 5 | \n",
+ " (15, 132, 12) | \n",
+ " 8228 | \n",
+ " Maltese Falcon, The (a.k.a. Dangerous Female) ... | \n",
+ "
\n",
+ " \n",
+ " | 6 | \n",
+ " (208, 107, 255) | \n",
+ " 1007 | \n",
+ " Apple Dumpling Gang, The (1975) Children Comed... | \n",
+ "
\n",
+ " \n",
+ " | 7 | \n",
+ " (208, 107, 255) | \n",
+ " 2016 | \n",
+ " Apple Dumpling Gang Rides Again, The (1979) Ch... | \n",
+ "
\n",
+ " \n",
+ " | 8 | \n",
+ " (99, 84, 111) | \n",
+ " 1644 | \n",
+ " I Know What You Did Last Summer (1997) Horror ... | \n",
+ "
\n",
+ " \n",
+ " | 9 | \n",
+ " (99, 84, 111) | \n",
+ " 2338 | \n",
+ " I Still Know What You Did Last Summer (1998) H... | \n",
+ "
\n",
+ " \n",
+ " | 10 | \n",
+ " (183, 144, 185) | \n",
+ " 2124 | \n",
+ " Addams Family, The (1991) Children Comedy Fantasy | \n",
+ "
\n",
+ " \n",
+ " | 11 | \n",
+ " (183, 144, 185) | \n",
+ " 27075 | \n",
+ " Addams Family Reunion (1998) Children Comedy F... | \n",
+ "
\n",
+ " \n",
+ " | 12 | \n",
+ " (93, 110, 17) | \n",
+ " 3068 | \n",
+ " Verdict, The (1982) Drama Mystery | \n",
+ "
\n",
+ " \n",
+ " | 13 | \n",
+ " (93, 110, 17) | \n",
+ " 82121 | \n",
+ " Verdict, The (1946) Crime Drama Film-Noir Myst... | \n",
+ "
\n",
+ " \n",
+ " | 14 | \n",
+ " (180, 31, 0) | \n",
+ " 3149 | \n",
+ " Diamonds (1999) Mystery | \n",
+ "
\n",
+ " \n",
+ " | 15 | \n",
+ " (180, 31, 0) | \n",
+ " 3219 | \n",
+ " Pacific Heights (1990) Mystery Thriller | \n",
+ "
\n",
+ " \n",
+ " | 16 | \n",
+ " (180, 60, 223) | \n",
+ " 3660 | \n",
+ " Puppet Master (1989) Horror Sci-Fi Thriller | \n",
+ "
\n",
+ " \n",
+ " | 17 | \n",
+ " (180, 60, 223) | \n",
+ " 3661 | \n",
+ " Puppet Master II (1991) Horror Sci-Fi Thriller | \n",
+ "
\n",
+ " \n",
+ " | 18 | \n",
+ " (179, 99, 81) | \n",
+ " 4544 | \n",
+ " Short Circuit 2 (1988) Comedy Sci-Fi | \n",
+ "
\n",
+ " \n",
+ " | 19 | \n",
+ " (179, 99, 81) | \n",
+ " 4545 | \n",
+ " Short Circuit (1986) Comedy Sci-Fi | \n",
+ "
\n",
+ " \n",
+ " | 20 | \n",
+ " (31, 154, 77) | \n",
+ " 5575 | \n",
+ " Alias Betty (Betty Fisher et autres histoires)... | \n",
+ "
\n",
+ " \n",
+ " | 21 | \n",
+ " (31, 154, 77) | \n",
+ " 35082 | \n",
+ " Lila Says (Lila dit ça) (2004) Crime Drama Rom... | \n",
+ "
\n",
+ " \n",
+ " | 22 | \n",
+ " (11, 204, 216) | \n",
+ " 5736 | \n",
+ " Faces of Death 3 (1985) Documentary Horror | \n",
+ "
\n",
+ " \n",
+ " | 23 | \n",
+ " (11, 204, 216) | \n",
+ " 5739 | \n",
+ " Faces of Death 6 (1996) Documentary Horror | \n",
+ "
\n",
+ " \n",
+ " | 24 | \n",
+ " (60, 149, 155) | \n",
+ " 6159 | \n",
+ " All the Real Girls (2003) Drama Romance | \n",
+ "
\n",
+ " \n",
+ " | 25 | \n",
+ " (60, 149, 155) | \n",
+ " 67365 | \n",
+ " After Sex (2007) Drama Romance | \n",
+ "
\n",
+ " \n",
+ " | 26 | \n",
+ " (2, 109, 34) | \n",
+ " 6407 | \n",
+ " Walk, Don't Run (1966) Comedy Romance | \n",
+ "
\n",
+ " \n",
+ " | 27 | \n",
+ " (2, 109, 34) | \n",
+ " 30742 | \n",
+ " Some Came Running (1958) Drama Romance | \n",
+ "
\n",
+ " \n",
+ " | 28 | \n",
+ " (14, 164, 68) | \n",
+ " 6428 | \n",
+ " Two Mules for Sister Sara (1970) Comedy War We... | \n",
+ "
\n",
+ " \n",
+ " | 29 | \n",
+ " (14, 164, 68) | \n",
+ " 7897 | \n",
+ " Ballad of Cable Hogue, The (1970) Comedy Western | \n",
+ "
\n",
+ " \n",
+ " | 30 | \n",
+ " (230, 2, 206) | \n",
+ " 8730 | \n",
+ " To End All Wars (2001) Action Drama War | \n",
+ "
\n",
+ " \n",
+ " | 31 | \n",
+ " (230, 2, 206) | \n",
+ " 102445 | \n",
+ " Star Trek Into Darkness (2013) Action Adventur... | \n",
+ "
\n",
+ " \n",
+ " | 32 | \n",
+ " (209, 218, 225) | \n",
+ " 8925 | \n",
+ " Spinning Boris (2003) Comedy Drama | \n",
+ "
\n",
+ " \n",
+ " | 33 | \n",
+ " (209, 218, 225) | \n",
+ " 55190 | \n",
+ " Love and Other Disasters (2006) Comedy Romance | \n",
+ "
\n",
+ " \n",
+ " | 34 | \n",
+ " (199, 49, 227) | \n",
+ " 8938 | \n",
+ " Tarnation (2003) Documentary | \n",
+ "
\n",
+ " \n",
+ " | 35 | \n",
+ " (199, 49, 227) | \n",
+ " 110387 | \n",
+ " Unknown Known, The (2013) Documentary | \n",
+ "
\n",
+ " \n",
+ " | 36 | \n",
+ " (8, 220, 31) | \n",
+ " 26007 | \n",
+ " Unknown Soldier, The (Tuntematon sotilas) (195... | \n",
+ "
\n",
+ " \n",
+ " | 37 | \n",
+ " (8, 220, 31) | \n",
+ " 26560 | \n",
+ " Unknown Soldier, The (Tuntematon sotilas) (198... | \n",
+ "
\n",
+ " \n",
+ " | 38 | \n",
+ " (223, 206, 99) | \n",
+ " 26270 | \n",
+ " Lone Wolf and Cub: Baby Cart at the River Styx... | \n",
+ "
\n",
+ " \n",
+ " | 39 | \n",
+ " (223, 206, 99) | \n",
+ " 62803 | \n",
+ " Lone Wolf and Cub: Baby Cart in Peril (Kozure ... | \n",
+ "
\n",
+ " \n",
+ " | 40 | \n",
+ " (197, 115, 131) | \n",
+ " 43908 | \n",
+ " London (2005) Drama | \n",
+ "
\n",
+ " \n",
+ " | 41 | \n",
+ " (197, 115, 131) | \n",
+ " 70948 | \n",
+ " London (1994) Documentary | \n",
+ "
\n",
+ " \n",
+ " | 42 | \n",
+ " (157, 159, 6) | \n",
+ " 59315 | \n",
+ " Iron Man (2008) Action Adventure Sci-Fi | \n",
+ "
\n",
+ " \n",
+ " | 43 | \n",
+ " (157, 159, 6) | \n",
+ " 77561 | \n",
+ " Iron Man 2 (2010) Action Adventure Sci-Fi Thri... | \n",
+ "
\n",
+ " \n",
+ " | 44 | \n",
+ " (116, 33, 244) | \n",
+ " 61026 | \n",
+ " Red Cliff (Chi bi) (2008) Action Adventure Dra... | \n",
+ "
\n",
+ " \n",
+ " | 45 | \n",
+ " (116, 33, 244) | \n",
+ " 68486 | \n",
+ " Red Cliff Part II (Chi Bi Xia: Jue Zhan Tian X... | \n",
+ "
\n",
+ " \n",
+ " | 46 | \n",
+ " (196, 29, 161) | \n",
+ " 62390 | \n",
+ " Autism: The Musical (2007) Documentary | \n",
+ "
\n",
+ " \n",
+ " | 47 | \n",
+ " (196, 29, 161) | \n",
+ " 85190 | \n",
+ " Public Speaking (2010) Documentary | \n",
+ "
\n",
+ " \n",
+ " | 48 | \n",
+ " (109, 81, 219) | \n",
+ " 65682 | \n",
+ " Underworld: Rise of the Lycans (2009) Action F... | \n",
+ "
\n",
+ " \n",
+ " | 49 | \n",
+ " (109, 81, 219) | \n",
+ " 91974 | \n",
+ " Underworld: Awakening (2012) Action Fantasy Ho... | \n",
+ "
\n",
+ " \n",
+ " | 50 | \n",
+ " (170, 55, 67) | \n",
+ " 77330 | \n",
+ " Red Riding: 1980 (2009) Crime Drama Mystery | \n",
+ "
\n",
+ " \n",
+ " | 51 | \n",
+ " (170, 55, 67) | \n",
+ " 77359 | \n",
+ " Red Riding: 1983 (2009) Crime Drama Mystery | \n",
+ "
\n",
+ " \n",
+ " | 52 | \n",
+ " (118, 203, 203) | \n",
+ " 86355 | \n",
+ " Atlas Shrugged: Part 1 (2011) Drama Mystery Sc... | \n",
+ "
\n",
+ " \n",
+ " | 53 | \n",
+ " (118, 203, 203) | \n",
+ " 97324 | \n",
+ " Atlas Shrugged: Part II (2012) Drama Mystery S... | \n",
+ "
\n",
+ " \n",
+ " | 54 | \n",
+ " (250, 128, 116) | \n",
+ " 88129 | \n",
+ " Drive (2011) Crime Drama Film-Noir Thriller | \n",
+ "
\n",
+ " \n",
+ " | 55 | \n",
+ " (250, 128, 116) | \n",
+ " 101137 | \n",
+ " Dead Man Down (2013) Action Crime Drama Romanc... | \n",
+ "
\n",
+ " \n",
+ " | 56 | \n",
+ " (34, 147, 143) | \n",
+ " 106204 | \n",
+ " Pieta (2013) Drama | \n",
+ "
\n",
+ " \n",
+ " | 57 | \n",
+ " (34, 147, 143) | \n",
+ " 111249 | \n",
+ " Belle (2013) Drama | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " SID item_id \\\n",
+ "0 (235, 229, 202) 421 \n",
+ "1 (235, 229, 202) 986 \n",
+ "2 (10, 64, 101) 502 \n",
+ "3 (10, 64, 101) 2422 \n",
+ "4 (15, 132, 12) 913 \n",
+ "5 (15, 132, 12) 8228 \n",
+ "6 (208, 107, 255) 1007 \n",
+ "7 (208, 107, 255) 2016 \n",
+ "8 (99, 84, 111) 1644 \n",
+ "9 (99, 84, 111) 2338 \n",
+ "10 (183, 144, 185) 2124 \n",
+ "11 (183, 144, 185) 27075 \n",
+ "12 (93, 110, 17) 3068 \n",
+ "13 (93, 110, 17) 82121 \n",
+ "14 (180, 31, 0) 3149 \n",
+ "15 (180, 31, 0) 3219 \n",
+ "16 (180, 60, 223) 3660 \n",
+ "17 (180, 60, 223) 3661 \n",
+ "18 (179, 99, 81) 4544 \n",
+ "19 (179, 99, 81) 4545 \n",
+ "20 (31, 154, 77) 5575 \n",
+ "21 (31, 154, 77) 35082 \n",
+ "22 (11, 204, 216) 5736 \n",
+ "23 (11, 204, 216) 5739 \n",
+ "24 (60, 149, 155) 6159 \n",
+ "25 (60, 149, 155) 67365 \n",
+ "26 (2, 109, 34) 6407 \n",
+ "27 (2, 109, 34) 30742 \n",
+ "28 (14, 164, 68) 6428 \n",
+ "29 (14, 164, 68) 7897 \n",
+ "30 (230, 2, 206) 8730 \n",
+ "31 (230, 2, 206) 102445 \n",
+ "32 (209, 218, 225) 8925 \n",
+ "33 (209, 218, 225) 55190 \n",
+ "34 (199, 49, 227) 8938 \n",
+ "35 (199, 49, 227) 110387 \n",
+ "36 (8, 220, 31) 26007 \n",
+ "37 (8, 220, 31) 26560 \n",
+ "38 (223, 206, 99) 26270 \n",
+ "39 (223, 206, 99) 62803 \n",
+ "40 (197, 115, 131) 43908 \n",
+ "41 (197, 115, 131) 70948 \n",
+ "42 (157, 159, 6) 59315 \n",
+ "43 (157, 159, 6) 77561 \n",
+ "44 (116, 33, 244) 61026 \n",
+ "45 (116, 33, 244) 68486 \n",
+ "46 (196, 29, 161) 62390 \n",
+ "47 (196, 29, 161) 85190 \n",
+ "48 (109, 81, 219) 65682 \n",
+ "49 (109, 81, 219) 91974 \n",
+ "50 (170, 55, 67) 77330 \n",
+ "51 (170, 55, 67) 77359 \n",
+ "52 (118, 203, 203) 86355 \n",
+ "53 (118, 203, 203) 97324 \n",
+ "54 (250, 128, 116) 88129 \n",
+ "55 (250, 128, 116) 101137 \n",
+ "56 (34, 147, 143) 106204 \n",
+ "57 (34, 147, 143) 111249 \n",
+ "\n",
+ " text \n",
+ "0 Black Beauty (1994) Adventure Children Drama \n",
+ "1 Fly Away Home (1996) Adventure Children \n",
+ "2 Next Karate Kid, The (1994) Action Children Ro... \n",
+ "3 Karate Kid, Part III, The (1989) Action Advent... \n",
+ "4 Maltese Falcon, The (1941) Film-Noir Mystery \n",
+ "5 Maltese Falcon, The (a.k.a. Dangerous Female) ... \n",
+ "6 Apple Dumpling Gang, The (1975) Children Comed... \n",
+ "7 Apple Dumpling Gang Rides Again, The (1979) Ch... \n",
+ "8 I Know What You Did Last Summer (1997) Horror ... \n",
+ "9 I Still Know What You Did Last Summer (1998) H... \n",
+ "10 Addams Family, The (1991) Children Comedy Fantasy \n",
+ "11 Addams Family Reunion (1998) Children Comedy F... \n",
+ "12 Verdict, The (1982) Drama Mystery \n",
+ "13 Verdict, The (1946) Crime Drama Film-Noir Myst... \n",
+ "14 Diamonds (1999) Mystery \n",
+ "15 Pacific Heights (1990) Mystery Thriller \n",
+ "16 Puppet Master (1989) Horror Sci-Fi Thriller \n",
+ "17 Puppet Master II (1991) Horror Sci-Fi Thriller \n",
+ "18 Short Circuit 2 (1988) Comedy Sci-Fi \n",
+ "19 Short Circuit (1986) Comedy Sci-Fi \n",
+ "20 Alias Betty (Betty Fisher et autres histoires)... \n",
+ "21 Lila Says (Lila dit ça) (2004) Crime Drama Rom... \n",
+ "22 Faces of Death 3 (1985) Documentary Horror \n",
+ "23 Faces of Death 6 (1996) Documentary Horror \n",
+ "24 All the Real Girls (2003) Drama Romance \n",
+ "25 After Sex (2007) Drama Romance \n",
+ "26 Walk, Don't Run (1966) Comedy Romance \n",
+ "27 Some Came Running (1958) Drama Romance \n",
+ "28 Two Mules for Sister Sara (1970) Comedy War We... \n",
+ "29 Ballad of Cable Hogue, The (1970) Comedy Western \n",
+ "30 To End All Wars (2001) Action Drama War \n",
+ "31 Star Trek Into Darkness (2013) Action Adventur... \n",
+ "32 Spinning Boris (2003) Comedy Drama \n",
+ "33 Love and Other Disasters (2006) Comedy Romance \n",
+ "34 Tarnation (2003) Documentary \n",
+ "35 Unknown Known, The (2013) Documentary \n",
+ "36 Unknown Soldier, The (Tuntematon sotilas) (195... \n",
+ "37 Unknown Soldier, The (Tuntematon sotilas) (198... \n",
+ "38 Lone Wolf and Cub: Baby Cart at the River Styx... \n",
+ "39 Lone Wolf and Cub: Baby Cart in Peril (Kozure ... \n",
+ "40 London (2005) Drama \n",
+ "41 London (1994) Documentary \n",
+ "42 Iron Man (2008) Action Adventure Sci-Fi \n",
+ "43 Iron Man 2 (2010) Action Adventure Sci-Fi Thri... \n",
+ "44 Red Cliff (Chi bi) (2008) Action Adventure Dra... \n",
+ "45 Red Cliff Part II (Chi Bi Xia: Jue Zhan Tian X... \n",
+ "46 Autism: The Musical (2007) Documentary \n",
+ "47 Public Speaking (2010) Documentary \n",
+ "48 Underworld: Rise of the Lycans (2009) Action F... \n",
+ "49 Underworld: Awakening (2012) Action Fantasy Ho... \n",
+ "50 Red Riding: 1980 (2009) Crime Drama Mystery \n",
+ "51 Red Riding: 1983 (2009) Crime Drama Mystery \n",
+ "52 Atlas Shrugged: Part 1 (2011) Drama Mystery Sc... \n",
+ "53 Atlas Shrugged: Part II (2012) Drama Mystery S... \n",
+ "54 Drive (2011) Crime Drama Film-Noir Thriller \n",
+ "55 Dead Man Down (2013) Action Crime Drama Romanc... \n",
+ "56 Pieta (2013) Drama \n",
+ "57 Belle (2013) Drama "
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "from utils import find_conflicts_df\n",
+ "\n",
+ "find_conflicts_df(tokenizer.id2sid, meta)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "model-header",
+ "metadata": {},
+ "source": [
+ "## 4. Create a TIGERModel\n",
+ "\n",
+ "The `TIGERModel` wraps the TIGER Transformer encoder-decoder.\n",
+ "It accepts a pretrained tokenizer and model/training hyperparameters."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "create-model",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tiger = TIGERModel(\n",
+ " tokenizer=tokenizer,\n",
+ " hidden_units=128,\n",
+ " num_blocks=4,\n",
+ " num_heads=6,\n",
+ " dropout_rate=0.1,\n",
+ " max_length=20,\n",
+ " # Training hyperparams\n",
+ " lr=1e-3,\n",
+ " lr_schedule=\"cosine\",\n",
+ " max_epochs=3,\n",
+ " patience=None,\n",
+ " batch_size=256,\n",
+ " eval_batch_size=64,\n",
+ " beam_size=20,\n",
+ " top_k=10,\n",
+ " d_kv=64,\n",
+ " device=\"cuda\"\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "split-header",
+ "metadata": {},
+ "source": [
+ "## 5. Split data with leave-one-out\n",
+ "\n",
+ "`loo_split()` performs a leave-one-out split on the interactions DataFrame:\n",
+ "- **train**: all interactions except the last 2 per user\n",
+ "- **val**: all interactions except the last 1 per user\n",
+ "- **test**: all interactions\n",
+ "\n",
+ "Users with fewer than 3 interactions are dropped."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "split",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Train: 9,704,103 interactions, 136,674 users\n",
+ "Val: 9,840,777 interactions, 136,674 users\n",
+ "Test: 9,977,451 interactions, 136,674 users\n"
+ ]
+ }
+ ],
+ "source": [
+ "train_df, val_df, test_df = loo_split(interactions)\n",
+ "\n",
+ "print(f\"Train: {len(train_df):,} interactions, {train_df['user_id'].nunique():,} users\")\n",
+ "print(f\"Val: {len(val_df):,} interactions, {val_df['user_id'].nunique():,} users\")\n",
+ "print(f\"Test: {len(test_df):,} interactions, {test_df['user_id'].nunique():,} users\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "train-header",
+ "metadata": {},
+ "source": [
+ "## 6. Train the model\n",
+ "\n",
+ "`fit()` accepts pre-split train and validation DataFrames.\n",
+ "You control the split strategy -- the model does not split internally."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "train",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
+ "Trainer already configured with model summary callbacks: []. Skipping setting a default `ModelSummary` callback.\n",
+ "GPU available: True (cuda), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "HPU available: False, using: 0 HPUs\n",
+ "You are using a CUDA device ('NVIDIA H100 80GB HBM3 MIG 2g.20gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n",
+ "\n",
+ " | Name | Type | Params | Mode \n",
+ "-------------------------------------------\n",
+ "0 | model | TIGERNet | 3.6 M | train\n",
+ "-------------------------------------------\n",
+ "3.6 M Trainable params\n",
+ "0 Non-trainable params\n",
+ "3.6 M Total params\n",
+ "14.455 Total estimated model params size (MB)\n",
+ "150 Modules in train mode\n",
+ "0 Modules in eval mode\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "82895486bd2043d3983219c3ac546537",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Sanity Checking: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "c05b7d44c7f8442b8da633b355bd5954",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Training: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1dfe8b62860d44bcb658f67ba97eb8cd",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1ae6f2e846cb4e35b0aa302e6a03fd1f",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0c9483bbff464e1eb63f3eb0e54ac88a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Validation: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "`Trainer.fit` stopped: `max_epochs=3` reached.\n"
+ ]
+ }
+ ],
+ "source": [
+ "tiger.fit(train_df, val_df)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "evaluate-header",
+ "metadata": {},
+ "source": [
+ "## 7. Evaluate the model\n",
+ "\n",
+ "`evaluate()` computes ranking metrics (Hit@k, NDCG@k, MRR@k) together with Gini@k and Coverage@k on a test set.\n",
+ "For each user, the last item in the sequence is treated as the ground-truth target."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "evaluate",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
+ "GPU available: True (cuda), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "HPU available: False, using: 0 HPUs\n",
+ "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "1035bdc9042c425b93823f7baf140be3",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "Testing: | | 0/? [00:00, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃ Test metric ┃ DataLoader 0 ┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│ Coverage@10 │ 0.0023603341542184353 │\n",
+ "│ Gini@10 │ 0.7614988684654236 │\n",
+ "│ Hit@10 │ 0.0728229209780693 │\n",
+ "│ MRR@10 │ 0.024664802476763725 │\n",
+ "│ NDCG@10 │ 0.03576871380209923 │\n",
+ "└───────────────────────────┴───────────────────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+ "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n",
+ "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+ "│\u001b[36m \u001b[0m\u001b[36m Coverage@10 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0023603341542184353 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m Gini@10 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.7614988684654236 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m Hit@10 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0728229209780693 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m MRR@10 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.024664802476763725 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "│\u001b[36m \u001b[0m\u001b[36m NDCG@10 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.03576871380209923 \u001b[0m\u001b[35m \u001b[0m│\n",
+ "└───────────────────────────┴───────────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Hit@10: 0.0728\n",
+ "NDCG@10: 0.0358\n",
+ "MRR@10: 0.0247\n",
+ "Gini@10: 0.7615\n",
+ "Coverage@10: 0.0024\n"
+ ]
+ }
+ ],
+ "source": [
+ "results = tiger.evaluate(test_df)\n",
+ "\n",
+ "for metric, value in results.items():\n",
+ " print(f\"{metric}: {value:.4f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "predict-header",
+ "metadata": {},
+ "source": [
+ "## 8. Generate recommendations\n",
+ "\n",
+ "`predict()` takes a DataFrame of user interaction histories and returns\n",
+ "a DataFrame of recommendations with columns: `user_id`, `item_id`, `score`, `rank`.\n",
+ "\n",
+ "We can join with movie metadata to see the actual titles."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "predict",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " user_id | \n",
+ " item_id | \n",
+ " score | \n",
+ " rank | \n",
+ " title | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " 1 | \n",
+ " 5952 | \n",
+ " -4.595025 | \n",
+ " 1 | \n",
+ " Lord of the Rings: The Two Towers, The (2002) | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " 1 | \n",
+ " 3578 | \n",
+ " -4.673381 | \n",
+ " 2 | \n",
+ " Gladiator (2000) | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " 1 | \n",
+ " 59315 | \n",
+ " -4.751242 | \n",
+ " 3 | \n",
+ " Iron Man (2008) | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " 1 | \n",
+ " 6539 | \n",
+ " -4.858047 | \n",
+ " 4 | \n",
+ " Pirates of the Caribbean: The Curse of the Bla... | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " 1 | \n",
+ " 5349 | \n",
+ " -5.812282 | \n",
+ " 5 | \n",
+ " Spider-Man (2002) | \n",
+ "
\n",
+ " \n",
+ " | 5 | \n",
+ " 2 | \n",
+ " 1196 | \n",
+ " -3.895408 | \n",
+ " 1 | \n",
+ " Star Wars: Episode V - The Empire Strikes Back... | \n",
+ "
\n",
+ " \n",
+ " | 6 | \n",
+ " 2 | \n",
+ " 260 | \n",
+ " -4.124645 | \n",
+ " 2 | \n",
+ " Star Wars: Episode IV - A New Hope (1977) | \n",
+ "
\n",
+ " \n",
+ " | 7 | \n",
+ " 2 | \n",
+ " 1210 | \n",
+ " -4.179870 | \n",
+ " 3 | \n",
+ " Star Wars: Episode VI - Return of the Jedi (1983) | \n",
+ "
\n",
+ " \n",
+ " | 8 | \n",
+ " 2 | \n",
+ " 541 | \n",
+ " -4.414659 | \n",
+ " 4 | \n",
+ " Blade Runner (1982) | \n",
+ "
\n",
+ " \n",
+ " | 9 | \n",
+ " 2 | \n",
+ " 1240 | \n",
+ " -4.478841 | \n",
+ " 5 | \n",
+ " Terminator, The (1984) | \n",
+ "
\n",
+ " \n",
+ " | 10 | \n",
+ " 3 | \n",
+ " 1210 | \n",
+ " -4.071971 | \n",
+ " 1 | \n",
+ " Star Wars: Episode VI - Return of the Jedi (1983) | \n",
+ "
\n",
+ " \n",
+ " | 11 | \n",
+ " 3 | \n",
+ " 780 | \n",
+ " -4.414012 | \n",
+ " 2 | \n",
+ " Independence Day (a.k.a. ID4) (1996) | \n",
+ "
\n",
+ " \n",
+ " | 12 | \n",
+ " 3 | \n",
+ " 260 | \n",
+ " -4.510668 | \n",
+ " 3 | \n",
+ " Star Wars: Episode IV - A New Hope (1977) | \n",
+ "
\n",
+ " \n",
+ " | 13 | \n",
+ " 3 | \n",
+ " 1196 | \n",
+ " -4.572967 | \n",
+ " 4 | \n",
+ " Star Wars: Episode V - The Empire Strikes Back... | \n",
+ "
\n",
+ " \n",
+ " | 14 | \n",
+ " 3 | \n",
+ " 1240 | \n",
+ " -4.647929 | \n",
+ " 5 | \n",
+ " Terminator, The (1984) | \n",
+ "
\n",
+ " \n",
+ " | 15 | \n",
+ " 4 | \n",
+ " 780 | \n",
+ " -3.860251 | \n",
+ " 1 | \n",
+ " Independence Day (a.k.a. ID4) (1996) | \n",
+ "
\n",
+ " \n",
+ " | 16 | \n",
+ " 4 | \n",
+ " 648 | \n",
+ " -4.608992 | \n",
+ " 2 | \n",
+ " Mission: Impossible (1996) | \n",
+ "
\n",
+ " \n",
+ " | 17 | \n",
+ " 4 | \n",
+ " 474 | \n",
+ " -4.694191 | \n",
+ " 3 | \n",
+ " In the Line of Fire (1993) | \n",
+ "
\n",
+ " \n",
+ " | 18 | \n",
+ " 4 | \n",
+ " 736 | \n",
+ " -4.714936 | \n",
+ " 4 | \n",
+ " Twister (1996) | \n",
+ "
\n",
+ " \n",
+ " | 19 | \n",
+ " 4 | \n",
+ " 95 | \n",
+ " -4.940084 | \n",
+ " 5 | \n",
+ " Broken Arrow (1996) | \n",
+ "
\n",
+ " \n",
+ " | 20 | \n",
+ " 5 | \n",
+ " 260 | \n",
+ " -4.890990 | \n",
+ " 1 | \n",
+ " Star Wars: Episode IV - A New Hope (1977) | \n",
+ "
\n",
+ " \n",
+ " | 21 | \n",
+ " 5 | \n",
+ " 1196 | \n",
+ " -5.130278 | \n",
+ " 2 | \n",
+ " Star Wars: Episode V - The Empire Strikes Back... | \n",
+ "
\n",
+ " \n",
+ " | 22 | \n",
+ " 5 | \n",
+ " 1198 | \n",
+ " -5.216189 | \n",
+ " 3 | \n",
+ " Raiders of the Lost Ark (Indiana Jones and the... | \n",
+ "
\n",
+ " \n",
+ " | 23 | \n",
+ " 5 | \n",
+ " 541 | \n",
+ " -5.227041 | \n",
+ " 4 | \n",
+ " Blade Runner (1982) | \n",
+ "
\n",
+ " \n",
+ " | 24 | \n",
+ " 5 | \n",
+ " 457 | \n",
+ " -5.267686 | \n",
+ " 5 | \n",
+ " Fugitive, The (1993) | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " user_id item_id score rank \\\n",
+ "0 1 5952 -4.595025 1 \n",
+ "1 1 3578 -4.673381 2 \n",
+ "2 1 59315 -4.751242 3 \n",
+ "3 1 6539 -4.858047 4 \n",
+ "4 1 5349 -5.812282 5 \n",
+ "5 2 1196 -3.895408 1 \n",
+ "6 2 260 -4.124645 2 \n",
+ "7 2 1210 -4.179870 3 \n",
+ "8 2 541 -4.414659 4 \n",
+ "9 2 1240 -4.478841 5 \n",
+ "10 3 1210 -4.071971 1 \n",
+ "11 3 780 -4.414012 2 \n",
+ "12 3 260 -4.510668 3 \n",
+ "13 3 1196 -4.572967 4 \n",
+ "14 3 1240 -4.647929 5 \n",
+ "15 4 780 -3.860251 1 \n",
+ "16 4 648 -4.608992 2 \n",
+ "17 4 474 -4.694191 3 \n",
+ "18 4 736 -4.714936 4 \n",
+ "19 4 95 -4.940084 5 \n",
+ "20 5 260 -4.890990 1 \n",
+ "21 5 1196 -5.130278 2 \n",
+ "22 5 1198 -5.216189 3 \n",
+ "23 5 541 -5.227041 4 \n",
+ "24 5 457 -5.267686 5 \n",
+ "\n",
+ " title \n",
+ "0 Lord of the Rings: The Two Towers, The (2002) \n",
+ "1 Gladiator (2000) \n",
+ "2 Iron Man (2008) \n",
+ "3 Pirates of the Caribbean: The Curse of the Bla... \n",
+ "4 Spider-Man (2002) \n",
+ "5 Star Wars: Episode V - The Empire Strikes Back... \n",
+ "6 Star Wars: Episode IV - A New Hope (1977) \n",
+ "7 Star Wars: Episode VI - Return of the Jedi (1983) \n",
+ "8 Blade Runner (1982) \n",
+ "9 Terminator, The (1984) \n",
+ "10 Star Wars: Episode VI - Return of the Jedi (1983) \n",
+ "11 Independence Day (a.k.a. ID4) (1996) \n",
+ "12 Star Wars: Episode IV - A New Hope (1977) \n",
+ "13 Star Wars: Episode V - The Empire Strikes Back... \n",
+ "14 Terminator, The (1984) \n",
+ "15 Independence Day (a.k.a. ID4) (1996) \n",
+ "16 Mission: Impossible (1996) \n",
+ "17 In the Line of Fire (1993) \n",
+ "18 Twister (1996) \n",
+ "19 Broken Arrow (1996) \n",
+ "20 Star Wars: Episode IV - A New Hope (1977) \n",
+ "21 Star Wars: Episode V - The Empire Strikes Back... \n",
+ "22 Raiders of the Lost Ark (Indiana Jones and the... \n",
+ "23 Blade Runner (1982) \n",
+ "24 Fugitive, The (1993) "
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "sample_users = interactions[\"user_id\"].unique()[:5]\n",
+ "sample_interactions = interactions[interactions[\"user_id\"].isin(sample_users)]\n",
+ "\n",
+ "recommendations = tiger.predict(sample_interactions, top_k=5)\n",
+ "recommendations.merge(movies[[\"item_id\", \"title\"]], on=\"item_id\", how=\"left\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "save-header",
+ "metadata": {},
+ "source": [
+ "## 9. Save and load the model\n",
+ "\n",
+ "`save()` writes the tokenizer, model weights, and config to a directory.\n",
+ "`load()` reconstructs the full `TIGERModel` from that directory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "save-load",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Saved to /tmp/tmp2mudyxqd/tiger_model\n",
+ "Contents: ['tokenizer.pt', 'model.pt', 'config.json']\n"
+ ]
+ },
+ {
+ "ename": "UnpicklingError",
+ "evalue": "Weights only load failed. This file can still be loaded, to do so you have two options, \u001b[1mdo those steps only if you trust the source of the checkpoint\u001b[0m. \n\t(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.\n\t(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.\n\tWeightsUnpickler error: Unsupported global: GLOBAL rectools.semantic.tokenizer.rqvae.RQVAE was not an allowed global by default. Please use `torch.serialization.add_safe_globals([rectools.semantic.tokenizer.rqvae.RQVAE])` or the `torch.serialization.safe_globals([rectools.semantic.tokenizer.rqvae.RQVAE])` context manager to allowlist this global if you trust this class/function.\n\nCheck the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mUnpicklingError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[14], line 10\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSaved to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msave_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mContents: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mos\u001b[38;5;241m.\u001b[39mlistdir(save_path)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 10\u001b[0m loaded_tiger \u001b[38;5;241m=\u001b[39m \u001b[43mTIGERModel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43msave_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mLoaded model with \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(loaded_tiger\u001b[38;5;241m.\u001b[39mtokenizer)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m item SIDs\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCodebook sizes: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mloaded_tiger\u001b[38;5;241m.\u001b[39mcodebook_sizes\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n",
+ "File \u001b[0;32m~/workspace/rectools/rectools/semantic/tiger/model.py:353\u001b[0m, in \u001b[0;36mTIGERModel.load\u001b[0;34m(cls, directory, device)\u001b[0m\n\u001b[1;32m 350\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(directory, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mconfig.json\u001b[39m\u001b[38;5;124m\"\u001b[39m), encoding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m 351\u001b[0m config \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mload(f)\n\u001b[0;32m--> 353\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m \u001b[43mSIDTokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdirectory\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtokenizer.pt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 355\u001b[0m tiger_model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m(tokenizer\u001b[38;5;241m=\u001b[39mtokenizer, device\u001b[38;5;241m=\u001b[39mdevice, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mconfig)\n\u001b[1;32m 357\u001b[0m state_dict \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mload(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(directory, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m), map_location\u001b[38;5;241m=\u001b[39mdevice)\n",
+ "File \u001b[0;32m~/workspace/rectools/rectools/semantic/tokenizer/model.py:165\u001b[0m, in \u001b[0;36mSIDTokenizer.load\u001b[0;34m(cls, path, device)\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mload\u001b[39m(\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28mcls\u001b[39m,\n\u001b[1;32m 148\u001b[0m path: Union[\u001b[38;5;28mstr\u001b[39m, PathLike[\u001b[38;5;28mstr\u001b[39m]],\n\u001b[1;32m 149\u001b[0m device: Optional[\u001b[38;5;28mstr\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 150\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSIDTokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 151\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Load serialized SIDTokenizer.\u001b[39;00m\n\u001b[1;32m 152\u001b[0m \n\u001b[1;32m 153\u001b[0m \u001b[38;5;124;03m Parameters\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;124;03m A de-serialized tokenizer.\u001b[39;00m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 165\u001b[0m loaded_data \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmap_location\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 167\u001b[0m tok \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m(\n\u001b[1;32m 168\u001b[0m quantizer\u001b[38;5;241m=\u001b[39mloaded_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquantizer\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 169\u001b[0m 
id2sid\u001b[38;5;241m=\u001b[39mloaded_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mid2sid\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 170\u001b[0m device\u001b[38;5;241m=\u001b[39mdevice,\n\u001b[1;32m 171\u001b[0m )\n\u001b[1;32m 173\u001b[0m \u001b[38;5;66;03m# Remind the codebooks to not re-init themselves\u001b[39;00m\n",
+ "File \u001b[0;32m~/.clearml/venvs-builds/3.12/lib/python3.12/site-packages/torch/serialization.py:1548\u001b[0m, in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)\u001b[0m\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _load(\n\u001b[1;32m 1541\u001b[0m opened_zipfile,\n\u001b[1;32m 1542\u001b[0m map_location,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1545\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpickle_load_args,\n\u001b[1;32m 1546\u001b[0m )\n\u001b[1;32m 1547\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m pickle\u001b[38;5;241m.\u001b[39mUnpicklingError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m-> 1548\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m pickle\u001b[38;5;241m.\u001b[39mUnpicklingError(_get_wo_message(\u001b[38;5;28mstr\u001b[39m(e))) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1549\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _load(\n\u001b[1;32m 1550\u001b[0m opened_zipfile,\n\u001b[1;32m 1551\u001b[0m map_location,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1554\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpickle_load_args,\n\u001b[1;32m 1555\u001b[0m )\n\u001b[1;32m 1556\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m mmap:\n",
+ "\u001b[0;31mUnpicklingError\u001b[0m: Weights only load failed. This file can still be loaded, to do so you have two options, \u001b[1mdo those steps only if you trust the source of the checkpoint\u001b[0m. \n\t(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.\n\t(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.\n\tWeightsUnpickler error: Unsupported global: GLOBAL rectools.semantic.tokenizer.rqvae.RQVAE was not an allowed global by default. Please use `torch.serialization.add_safe_globals([rectools.semantic.tokenizer.rqvae.RQVAE])` or the `torch.serialization.safe_globals([rectools.semantic.tokenizer.rqvae.RQVAE])` context manager to allowlist this global if you trust this class/function.\n\nCheck the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html."
+ ]
+ }
+ ],
+ "source": [
+ "import tempfile\n",
+ "import os\n",
+ "\n",
+ "with tempfile.TemporaryDirectory() as tmpdir:\n",
+ " save_path = os.path.join(tmpdir, \"tiger_model\")\n",
+ " tiger.save(save_path)\n",
+ " print(f\"Saved to {save_path}\")\n",
+ " print(f\"Contents: {os.listdir(save_path)}\")\n",
+ "\n",
+ " loaded_tiger = TIGERModel.load(save_path)\n",
+ " print(f\"\\nLoaded model with {len(loaded_tiger.tokenizer)} item SIDs\")\n",
+ " print(f\"Codebook sizes: {loaded_tiger.codebook_sizes}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "summary",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "The TIGER pipeline in RecTools:\n",
+ "\n",
+ "| Step | Function / Method | Input | Output |\n",
+ "|------|-------------------|-------|--------|\n",
+ "| Filtering | `k_core(df)` | Interactions DataFrame | Filtered DataFrame |\n",
+ "| Embedding | `encode_texts(texts, model)` | List of text strings | numpy embedding matrix |\n",
+ "| Tokenization | `SIDTokenizer(...).fit(ids, embs)` | Item IDs + embeddings | `SIDTokenizer` |\n",
+ "| Splitting | `loo_split(df)` | Interactions DataFrame | train, val, test DataFrames |\n",
+ "| Model creation | `TIGERModel(tokenizer)` | `SIDTokenizer` + hyperparams | `TIGERModel` |\n",
+ "| Training | `model.fit(train, val)` | Train + val DataFrames | Trained model (in-place) |\n",
+ "| Evaluation | `model.evaluate(test)` | Test DataFrame | Metrics dict |\n",
+ "| Prediction | `model.predict(df)` | Interactions DataFrame | Recommendations DataFrame |\n",
+ "| Persistence | `model.save(dir)` / `TIGERModel.load(dir)` | Directory path | Saved/loaded model |"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "3.12",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/tutorials/utils.py b/examples/tutorials/utils.py
index 8fdf8a8a..b8ea7d54 100644
--- a/examples/tutorials/utils.py
+++ b/examples/tutorials/utils.py
@@ -16,6 +16,7 @@
import os
import typing as tp
import warnings
+from collections import defaultdict
from pathlib import Path
import matplotlib.pyplot as plt
@@ -329,3 +330,61 @@ def get_results(path_to_load_res: str, metrics_to_show: tp.List[str], show_loss:
)
pivot_results.columns = pivot_results.columns.droplevel(1)
return pivot_results[metrics_to_show]
+
+
+def find_conflicts(id2sid: dict) -> dict[tuple[int, ...], list[int]]:
+    """Group item IDs by SID and keep only the SIDs that more than one item maps to.
+
+    Parameters
+    ----------
+    id2sid : dict
+        Item ID -> SID tuple mapping, such as ``SIDTokenizer.id2sid``.
+
+    Returns
+    -------
+    dict[tuple[int, ...], list[int]]
+        Each colliding SID paired with the item IDs assigned to it.
+    """
+    buckets: dict[tuple[int, ...], list[int]] = defaultdict(list)
+    for key in id2sid:
+        buckets[id2sid[key]].append(key)
+    return {code: members for code, members in buckets.items() if len(members) >= 2}
+
+
+def find_conflicts_df(
+    id2sid: dict,
+    metadata: pd.DataFrame,
+    item_col: str = "item_id",
+    text_col: str = "text",
+) -> pd.DataFrame:
+    """Build a table of items whose SIDs collide, annotated with their text.
+
+    Parameters
+    ----------
+    id2sid : dict
+        Item ID -> SID tuple mapping, such as ``SIDTokenizer.id2sid``.
+    metadata : pd.DataFrame
+        Item metadata; expected to contain the ``item_col`` and ``text_col`` columns.
+    item_col : str
+        Column of ``metadata`` holding item IDs.
+    text_col : str
+        Column of ``metadata`` holding item text.
+
+    Returns
+    -------
+    pd.DataFrame
+        One row per colliding item with columns ``SID``, ``item_id``, ``text``,
+        largest SID groups first (missing text becomes ``""``). Empty if no conflicts.
+    """
+    columns = ["SID", "item_id", "text"]
+    clashes = find_conflicts(id2sid)
+    if len(clashes) == 0:
+        return pd.DataFrame(columns=columns)
+    text_by_item = dict(zip(metadata[item_col], metadata[text_col]))
+    by_size = sorted(clashes.items(), key=lambda pair: len(pair[1]), reverse=True)
+    records = [
+        {"SID": sid, "item_id": member, "text": text_by_item.get(member, "")}
+        for sid, members in by_size
+        for member in members
+    ]
+    return pd.DataFrame(records, columns=columns)
diff --git a/poetry.lock b/poetry.lock
index b1b15e5b..9c0bcb2b 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,15 @@
-# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
+
+[[package]]
+name = "absl-py"
+version = "2.3.1"
+description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py."
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "absl_py-2.3.1-py3-none-any.whl", hash = "sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d"},
+ {file = "absl_py-2.3.1.tar.gz", hash = "sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9"},
+]
[[package]]
name = "aiohappyeyeballs"
@@ -1631,6 +1642,17 @@ files = [
{file = "imagesize-1.4.1.tar.gz", hash = "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a"},
]
+[[package]]
+name = "immutabledict"
+version = "4.3.1"
+description = "Immutable wrapper around dictionaries (a fork of frozendict)"
+optional = true
+python-versions = "<4.0,>=3.8"
+files = [
+ {file = "immutabledict-4.3.1-py3-none-any.whl", hash = "sha256:c9facdc0ff30fdb8e35bd16532026cac472a549e182c94fa201b51b25e4bf7bf"},
+ {file = "immutabledict-4.3.1.tar.gz", hash = "sha256:f844a669106cfdc73f47b1a9da003782fb17dc955a54c80972e0d93d1c63c514"},
+]
+
[[package]]
name = "implicit"
version = "0.7.2"
@@ -1987,6 +2009,83 @@ files = [
{file = "jupyterlab_widgets-3.0.16.tar.gz", hash = "sha256:423da05071d55cf27a9e602216d35a3a65a3e41cdf9c5d3b643b814ce38c19e0"},
]
+[[package]]
+name = "k-means-constrained"
+version = "0.7.3"
+description = "K-Means clustering constrained with minimum and maximum cluster size"
+optional = true
+python-versions = ">=3.8"
+files = [
+ {file = "k-means-constrained-0.7.3.tar.gz", hash = "sha256:2b5bf8317614d749cf3da6b6f4adf52dee34cdf0dc16b8a26b40d72f1a000dff"},
+ {file = "k_means_constrained-0.7.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8d8304b1688d16d31b2803a94504484524ee8977fa2eaac53e9a33a4f1ef0c9e"},
+ {file = "k_means_constrained-0.7.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b1d79e9032c0471f522a964fa1804f776a980fc932c2793d0d256d33b762119"},
+ {file = "k_means_constrained-0.7.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68e1736e9b8a6dfd2af0de5200ef03fc5b13ef55907e3f5431289ee519d704ae"},
+ {file = "k_means_constrained-0.7.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15d838a941043fe6480ff36f7d8e9128fb73f193248c42b6fdbaf2aa5b8ce00e"},
+ {file = "k_means_constrained-0.7.3-cp310-cp310-win_amd64.whl", hash = "sha256:521881df8362dce22875e431150f74a4f6f17cf798bbe2f3711d145f97197335"},
+ {file = "k_means_constrained-0.7.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:04e07d7fecc36af499f13efd0925712a7dc8382fb10f5582d780f3c7a2304c71"},
+ {file = "k_means_constrained-0.7.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2f5034fa54879ff7e15882e3d4cfe572e24549447717a6fcf724bbe1520e06fe"},
+ {file = "k_means_constrained-0.7.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08288fe3011783a1ffd0c18cf314ea604e9d1ba68d601ab17584428097c15ad4"},
+ {file = "k_means_constrained-0.7.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b279f4807314bbeabf5bd5fb78de382607d8e97b31bb277d9a09d9a8b412b969"},
+ {file = "k_means_constrained-0.7.3-cp311-cp311-win_amd64.whl", hash = "sha256:d6419494fdfbad3448f2c6ab8fb9492cadaca723a3b3d1a7acbacef894095a17"},
+ {file = "k_means_constrained-0.7.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f229dc1791a9d792b68b3c79555dc7d94462b56fd2ffc24914cb3c115a12dbb6"},
+ {file = "k_means_constrained-0.7.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d116ad48100c6ee3f57ee4c675792a4a47569dbe817c25899aa29bcdedb184aa"},
+ {file = "k_means_constrained-0.7.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05b6b84ba06a79c3d50e787770a7ba02f5ba50df3f2a3dc1a106e5155de656b8"},
+ {file = "k_means_constrained-0.7.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da68417759907b0abf127b876f2b92264a0d6fc056165d8080c1557dd79e339a"},
+ {file = "k_means_constrained-0.7.3-cp38-cp38-win_amd64.whl", hash = "sha256:98beab00776bd887f4fe28322b1183f85da3d4e2e2d8f2960d8d9edf21412a3b"},
+ {file = "k_means_constrained-0.7.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e4fb34b8ac2c9857505f8d34b8f77e2dd21e2cba28e811f4bb9a7e37844ba6c7"},
+ {file = "k_means_constrained-0.7.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3965bc53eddb00777fc0fd34bc4332fde29c283c93ac82326a3678b175e180e5"},
+ {file = "k_means_constrained-0.7.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1354253318ed7c688ac6f53d6e29af8065e5a64965bb535279cd6e40e66c63af"},
+ {file = "k_means_constrained-0.7.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26c07b374ed4dda4e396442e4661aed6442294858c8cf38715a0b44f816da590"},
+ {file = "k_means_constrained-0.7.3-cp39-cp39-win_amd64.whl", hash = "sha256:a4ffd8222d09f392c2c8822f489093c509f2c0d10e10a0d6f0e42a0abff4061d"},
+]
+
+[package.dependencies]
+joblib = "*"
+numpy = ">=1.23.0"
+ortools = ">=9.4.1874"
+scipy = ">=1.6.3"
+six = "*"
+
+[package.extras]
+dev = ["bump2version", "nose", "numpydoc", "pandas (>=1.0.4)", "pytest (>=5.1)", "scikit-learn (>=0.24.2)", "sphinx", "sphinx-rtd-theme", "twine"]
+docs = ["sphinx", "sphinx-rtd-theme"]
+
+[[package]]
+name = "k-means-constrained"
+version = "0.7.6"
+description = "K-Means clustering constrained with minimum and maximum cluster size"
+optional = true
+python-versions = "*"
+files = [
+ {file = "k_means_constrained-0.7.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:fe6dc2acfeafda2e7bdb0c3d290dbb0b4b57078ad47990a1ba783a26ed374286"},
+ {file = "k_means_constrained-0.7.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6fe7b59e45b8ef21720e16a4b75c69e242c5b447bc4ce230c0c33b9caf91f87b"},
+ {file = "k_means_constrained-0.7.6-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3629dfe2fa0629cc1ab8864f84e1b19651ab272e4016af0b95791215d0573692"},
+ {file = "k_means_constrained-0.7.6-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5fb57b1e6c4c2bdc2f80b9c993f97345bfa8987269da4ca50e860acd93f18126"},
+ {file = "k_means_constrained-0.7.6-cp310-cp310-win_amd64.whl", hash = "sha256:ad191818fe2c2a216ac1bc8a69790d3654cffaa88797e2f48c67f3f5c5315b39"},
+ {file = "k_means_constrained-0.7.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c841ea6629786697bb486ceeac0bec59d016794b468110ca6d521446ba12c995"},
+ {file = "k_means_constrained-0.7.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ef667d9191321cef7ef4ef2132c82c017b44e8fb334a91e975167024685b89be"},
+ {file = "k_means_constrained-0.7.6-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:eaaa77929eb6728beb16a2bd755104eee8068d807a1cbfce9a8b206f0cb3b2e2"},
+ {file = "k_means_constrained-0.7.6-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:28425e8f70d6035d482ff00a661c69c7a3e400403583916d596306daf3e79199"},
+ {file = "k_means_constrained-0.7.6-cp311-cp311-win_amd64.whl", hash = "sha256:a5783224f20ef4703dba4b21a2d0dc22e3b314a6ed5c064f6767021bd932a93b"},
+ {file = "k_means_constrained-0.7.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:05734675db6d5d43888394690640d283a2c1d5529435ecba87e172dffc806189"},
+ {file = "k_means_constrained-0.7.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34e08d5b707905df635ba1145e639269745c56c347157717265c3e48c2a04ba6"},
+ {file = "k_means_constrained-0.7.6-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9eb967fbf28b6bd91fdfed3bc543fae4973d8a474a3aa6dabd141236ebb36682"},
+ {file = "k_means_constrained-0.7.6-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:736d97ff642dfaaf0df4b63632b827ba4b1466794e111970fbfc99e2c52286a2"},
+ {file = "k_means_constrained-0.7.6-cp312-cp312-win_amd64.whl", hash = "sha256:599eb9e22e2c99b50dd7d9bf434789de7c0831972c508789be40a25b1796469b"},
+ {file = "k_means_constrained-0.7.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4cdf3de34fb7f24a78af05e3eaec063747258b8f71513faf85b72e9f601bc351"},
+ {file = "k_means_constrained-0.7.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:24ae15921827e655cfc13e6c9138e7f6b9be3bc488b2030eb19461de66aff42e"},
+ {file = "k_means_constrained-0.7.6-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2e2d9842493bf934eafadb68b689b715e1b2880ba58af05263ee2e11ccb3c6d7"},
+ {file = "k_means_constrained-0.7.6-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:876391bb2f1a79ece65d1dba8a76545da09b9ee4cb287f467e861c1043caaf4f"},
+ {file = "k_means_constrained-0.7.6-cp313-cp313-win_amd64.whl", hash = "sha256:a828bbc81dde339e7ae97934bf9dcb43c5b49884d1d9c8fb0ffa117f5c998233"},
+]
+
+[package.dependencies]
+joblib = "*"
+numpy = ">=2.1.1"
+ortools = ">=9.11.4210"
+scipy = ">=1.14.1"
+six = "*"
+
[[package]]
name = "kiwisolver"
version = "1.4.7"
@@ -3332,6 +3431,113 @@ files = [
{file = "nvidia_nvtx_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:619c8304aedc69f02ea82dd244541a83c3d9d40993381b3b590f1adaed3db41e"},
]
+[[package]]
+name = "ortools"
+version = "9.14.6206"
+description = "Google OR-Tools python libraries and modules"
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "ortools-9.14.6206-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:6e2364edd1577cd094e7c7121ec5fb0aa462a69a78ce29cdc40fa45943ff0091"},
+ {file = "ortools-9.14.6206-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:164b726b4d358ae68a018a52ff1999c0646d6f861b33676c2c83e2ddb60cfa13"},
+ {file = "ortools-9.14.6206-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ebb0e210969cc3246fe78dadf9038936a3a18edc8156e23a394e2bbcec962431"},
+ {file = "ortools-9.14.6206-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:174de2f04c106c7dcc5989560f2c0e065e78fba0ad0d1fd029897582f4823c3a"},
+ {file = "ortools-9.14.6206-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e6d994ebcf9cbdda1e20a75662967124e7e6ffd707c7f60b2db1a11f2104d384"},
+ {file = "ortools-9.14.6206-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5763472f8b05072c96c36c4eafadd9f6ffcdab38a81d8f0142fc408ad52a4342"},
+ {file = "ortools-9.14.6206-cp310-cp310-win_amd64.whl", hash = "sha256:6711516f837f06836ff9fda66fe4337b88c214f2ba6a921b84d3b05876f1fa8c"},
+ {file = "ortools-9.14.6206-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:8bcd8481846090585a4fac82800683555841685c49fa24578ad1e48a37918568"},
+ {file = "ortools-9.14.6206-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5af2bbf2fff7d922ba036e27d7ff378abecb24749380c86a77fa6208d5ba35cd"},
+ {file = "ortools-9.14.6206-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a6ab43490583c4bbf0fff4e51bb1c15675d5651c2e8e12ba974fd08e8c05a48f"},
+ {file = "ortools-9.14.6206-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9aa2c0c50a765c6a060960dcb0207bd6aeb6341f5adacb3d33e613b7e7409428"},
+ {file = "ortools-9.14.6206-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:64ec63fd92125499e9ca6b72700406dda161eefdfef92f04c35c5150391f89a4"},
+ {file = "ortools-9.14.6206-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:8651008f05257471f45a919ade5027afa12ab6f7a4fdf0a8bcc18c92032f8571"},
+ {file = "ortools-9.14.6206-cp311-cp311-win_amd64.whl", hash = "sha256:ca60877830a631545234e83e7f6bd55830334a4d0c2b51f1669b1f2698d58b84"},
+ {file = "ortools-9.14.6206-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:e38c8c4a184820cbfdb812a8d484f6506cf16993ce2a95c88bc1c9d23b17c63e"},
+ {file = "ortools-9.14.6206-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:db685073cbed9f8bfaa744f5e883f3dea57c93179b0abe1788276fd3b074fa61"},
+ {file = "ortools-9.14.6206-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d4bfb8bffb29991834cf4bde7048ca8ee8caed73e8dd21e5ec7de99a33bbfea0"},
+ {file = "ortools-9.14.6206-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eb464a698837e7f90ca5f9b3d748b6ddf553198a70032bc77824d1cd88695d2b"},
+ {file = "ortools-9.14.6206-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8f33deaeb7c3dda8ca1d29c5b9aa9c3a4f2ca9ecf34f12a1f809bb2995f41274"},
+ {file = "ortools-9.14.6206-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:086e7c2dc4f23efffb20a5e20f618c7d6adb99b2d94f684cab482387da3bc434"},
+ {file = "ortools-9.14.6206-cp312-cp312-win_amd64.whl", hash = "sha256:17c13b0bfde17ac57789ad35243edf1318ecd5db23cf949b75ab62480599f188"},
+ {file = "ortools-9.14.6206-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:8d0df7eef8ba53ad235e29018389259bad2e667d9594b9c2a412ed6a5756bd4e"},
+ {file = "ortools-9.14.6206-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:57dfe10844ce8331634d4723040fe249263fd490407346efc314c0bc656849b5"},
+ {file = "ortools-9.14.6206-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5c0c2c00a6e5d5c462e76fdda7dbd40d0f9139f1df4211d34b36906696248020"},
+ {file = "ortools-9.14.6206-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:38044cf39952d93cbcc02f6acdbe0a9bd3628fbf17f0d7eb0374060fa028c22e"},
+ {file = "ortools-9.14.6206-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:98564de773d709e1e49cb3c32f6917589c314f047786d88bd5f324c0eb7be96e"},
+ {file = "ortools-9.14.6206-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:80528b0ac72dc3de00cbeef2ce028517a476450b5877b1cda1b8ecb9fa98505e"},
+ {file = "ortools-9.14.6206-cp313-cp313-win_amd64.whl", hash = "sha256:47b1b15dcb085d32c61621b790259193aefa9e4577abadf233d47fbe7d0b81ef"},
+ {file = "ortools-9.14.6206-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d26a0f9ed97ef9d3384a9069923585f5f974c3fde555a41f4d6381fbe7840bc4"},
+ {file = "ortools-9.14.6206-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d40d8141667d47405f296a9f687058c566d7816586e9a672b59e9fcec8493133"},
+ {file = "ortools-9.14.6206-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:aefea81ed81aa937873efc520381785ed65380e52917f492ab566f46bbb5660d"},
+ {file = "ortools-9.14.6206-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f044bb277db3ab6a1b958728fe1cf14ca87c3800d67d7b321d876b48269340f6"},
+ {file = "ortools-9.14.6206-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:070dc7cebfa0df066acb6b9a6d02339351be8f91b2352b782ee7f40412207e20"},
+ {file = "ortools-9.14.6206-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5eb558a03b4ada501ecdea7b89f0d3bdf2cc6752e1728759ccf27923f592a8c2"},
+ {file = "ortools-9.14.6206-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:646329fa74a5c48c591b7fabfd26743f6d2de4e632b3b96ec596c47bfe19177a"},
+ {file = "ortools-9.14.6206-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:aa5161924f35b8244295acd0fab2a8171bb08ef8d5cfaf1913a21274475704cc"},
+ {file = "ortools-9.14.6206-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:e253526a026ae194aed544a0d065163f52a0c9cb606a1061c62df546877d5452"},
+ {file = "ortools-9.14.6206-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:dcb496ef633d884036770783f43bf8a47ff253ecdd8a8f5b95f00276ec241bfd"},
+ {file = "ortools-9.14.6206-cp39-cp39-win_amd64.whl", hash = "sha256:2733f635675de631fdc7b1611878ec9ee2f48a26434b7b3c07d0a0f535b92e03"},
+]
+
+[package.dependencies]
+absl-py = ">=2.0.0"
+immutabledict = ">=3.0.0"
+numpy = ">=1.13.3"
+pandas = ">=2.0.0"
+protobuf = ">=6.31.1,<6.32"
+typing-extensions = ">=4.12"
+
+[[package]]
+name = "ortools"
+version = "9.15.6755"
+description = "Google OR-Tools python libraries and modules"
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "ortools-9.15.6755-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:e4559603031ed371c5d86b1e9357fa49fb89236452e4b9bc429a0cf4a2fab05d"},
+ {file = "ortools-9.15.6755-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5bb2b434f4ae01ce81813d01db722d9dedcc452aede681211ee4d4df8963a410"},
+ {file = "ortools-9.15.6755-cp310-cp310-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b26655d25ab28030aef30e675e24d96d35940974de3a70ace01cf82ca301b69"},
+ {file = "ortools-9.15.6755-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:03424136aa48555e7f4d1bc73edeb99f80ec35a2e5f17700e2072640344980fc"},
+ {file = "ortools-9.15.6755-cp310-cp310-win_amd64.whl", hash = "sha256:4f4964f8ed47ac76b5cfd23238618299f5a3c289d8e0ed66a75885ba9766eb6f"},
+ {file = "ortools-9.15.6755-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:55e291560d2fdb9590656cbee06ba99ee7f2476bd7d316ff757eeab33e9b20d6"},
+ {file = "ortools-9.15.6755-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e51ae55569650e5381fd6e50c655ccf6368a9532f5720ea41396bb90e0247a21"},
+ {file = "ortools-9.15.6755-cp311-cp311-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3bcccd15ef3fc6ac10bfa11630ba6dfe437d4fd1374a5b33f4773b7fee0f877"},
+ {file = "ortools-9.15.6755-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7a85a68ffb3fc1967e78624f40d3707aae459e82e3f2d9fe02e91788b3c7bf2a"},
+ {file = "ortools-9.15.6755-cp311-cp311-win_amd64.whl", hash = "sha256:781fb09d6c9f46015291f706bd7c7e0815db1bec6e92c74716342fb7ea2d0532"},
+ {file = "ortools-9.15.6755-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:ae1c6e1fd844b4d756b22eb6c0ed574ea4342ee206d807c4f903039e748228fa"},
+ {file = "ortools-9.15.6755-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e16686c2b457fa6242c474ab890ee1712347ab53678e0d2fab307ae03e97a4b"},
+ {file = "ortools-9.15.6755-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3cd6bec0a2e00e3891a53e3b436f45a1000269f302085572f49e9856b7f8eaf0"},
+ {file = "ortools-9.15.6755-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:033836c0eb33bc72697a299e0caedbb25fc9d1cee0b13832d69cb30405f57b3e"},
+ {file = "ortools-9.15.6755-cp312-cp312-win_amd64.whl", hash = "sha256:487796301fd9dad55f9cf21f9313c834697f74306d1a59f002e152862f8eb1b5"},
+ {file = "ortools-9.15.6755-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:27a10474e62c9dceed37cfa0e4845c5ffaf792138ebf5b61483771b96f1290b6"},
+ {file = "ortools-9.15.6755-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:076565b803c85c4f87863e0616f537dd37f99c03e6f092e4068404f7b425d2b0"},
+ {file = "ortools-9.15.6755-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b85bd20259b146abce5e0721ce1bfd8fd273efc904216aa3be178c31b6d34057"},
+ {file = "ortools-9.15.6755-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ebd5aea00374e3aad7a78de59058aca5e871a26a3c385cd0860ef1d685d03c9a"},
+ {file = "ortools-9.15.6755-cp313-cp313-win_amd64.whl", hash = "sha256:caac1d48b967adb877da2abcaf82c28f0f908a7cc208a6a1bbe01bc69590816c"},
+ {file = "ortools-9.15.6755-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:82b4a8e6e4f9380b453ab5fa4382ea7ee91e628f9b8be89d9ad760b33fca3323"},
+ {file = "ortools-9.15.6755-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2d1f2fb2088e8953ccb902e68ffd06032cce0c7dcf7268b6135f3b6c553ca52b"},
+ {file = "ortools-9.15.6755-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:acdf06a167933307608e7eba23a9490255933504df44c8de5f62c48656c29688"},
+ {file = "ortools-9.15.6755-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1a0677270b0cd317a6b8dae42514264eaf5da5756c5bc7215eeea409424577df"},
+ {file = "ortools-9.15.6755-cp314-cp314-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:899b92afe3f775ab5867b9a8aa2850f81f2d95232db9b4ceec3456d69e6b8528"},
+ {file = "ortools-9.15.6755-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7181183cdcafe2b0d83ca5505b65048c7953dc7b5ad479361dded607964cc1b3"},
+ {file = "ortools-9.15.6755-cp314-cp314-win_amd64.whl", hash = "sha256:afabb869e5fabeb704bd8147b22bf8139dee042e55fabd0d447a996428009e0c"},
+ {file = "ortools-9.15.6755-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d9d07cddca201e25e2e219006a9d6cda10c7e9ee2c712c50d19d508f9ed8a888"},
+ {file = "ortools-9.15.6755-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:990838ad66a052e72a50e69da500878710e3420e91717fe88bf3071995caba9e"},
+ {file = "ortools-9.15.6755-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:73b229dbc2b225441cb3bc5790ea8d55f14e3cd7f32d5185784f60b102308457"},
+ {file = "ortools-9.15.6755-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8881e9620bf0bf8303891e171feb03d6e86c75a05fb9325a09ae7fbf93093f4a"},
+ {file = "ortools-9.15.6755-cp39-cp39-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:87c73acda29f03ded74c7d2388f6efcbe45fa45a3f2bae4d85e1b5f1cc4cd9c1"},
+ {file = "ortools-9.15.6755-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:76150c4dd5927d0ab138344a5121b1f63e32501371ca93e632e2e10d8261064b"},
+ {file = "ortools-9.15.6755-cp39-cp39-win_amd64.whl", hash = "sha256:d72c136fd6e4b112bf154680290490da0aca60b7ddfc4163581c069c67016d4a"},
+]
+
+[package.dependencies]
+absl-py = ">=2.0.0"
+immutabledict = ">=3.0.0"
+numpy = ">=2.0.2"
+pandas = ">=2.0.0"
+protobuf = ">=6.33.1,<6.34"
+typing-extensions = ">=4.12"
+
[[package]]
name = "packaging"
version = "26.0"
@@ -3874,6 +4080,43 @@ files = [
{file = "propcache-0.4.1.tar.gz", hash = "sha256:f48107a8c637e80362555f37ecf49abe20370e557cc4ab374f04ec4423c97c3d"},
]
+[[package]]
+name = "protobuf"
+version = "6.31.1"
+description = ""
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "protobuf-6.31.1-cp310-abi3-win32.whl", hash = "sha256:7fa17d5a29c2e04b7d90e5e32388b8bfd0e7107cd8e616feef7ed3fa6bdab5c9"},
+ {file = "protobuf-6.31.1-cp310-abi3-win_amd64.whl", hash = "sha256:426f59d2964864a1a366254fa703b8632dcec0790d8862d30034d8245e1cd447"},
+ {file = "protobuf-6.31.1-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:6f1227473dc43d44ed644425268eb7c2e488ae245d51c6866d19fe158e207402"},
+ {file = "protobuf-6.31.1-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:a40fc12b84c154884d7d4c4ebd675d5b3b5283e155f324049ae396b95ddebc39"},
+ {file = "protobuf-6.31.1-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:4ee898bf66f7a8b0bd21bce523814e6fbd8c6add948045ce958b73af7e8878c6"},
+ {file = "protobuf-6.31.1-cp39-cp39-win32.whl", hash = "sha256:0414e3aa5a5f3ff423828e1e6a6e907d6c65c1d5b7e6e975793d5590bdeecc16"},
+ {file = "protobuf-6.31.1-cp39-cp39-win_amd64.whl", hash = "sha256:8764cf4587791e7564051b35524b72844f845ad0bb011704c3736cce762d8fe9"},
+ {file = "protobuf-6.31.1-py3-none-any.whl", hash = "sha256:720a6c7e6b77288b85063569baae8536671b39f15cc22037ec7045658d80489e"},
+ {file = "protobuf-6.31.1.tar.gz", hash = "sha256:d8cac4c982f0b957a4dc73a80e2ea24fab08e679c0de9deb835f4a12d69aca9a"},
+]
+
+[[package]]
+name = "protobuf"
+version = "6.33.6"
+description = ""
+optional = true
+python-versions = ">=3.9"
+files = [
+ {file = "protobuf-6.33.6-cp310-abi3-win32.whl", hash = "sha256:7d29d9b65f8afef196f8334e80d6bc1d5d4adedb449971fefd3723824e6e77d3"},
+ {file = "protobuf-6.33.6-cp310-abi3-win_amd64.whl", hash = "sha256:0cd27b587afca21b7cfa59a74dcbd48a50f0a6400cfb59391340ad729d91d326"},
+ {file = "protobuf-6.33.6-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:9720e6961b251bde64edfdab7d500725a2af5280f3f4c87e57c0208376aa8c3a"},
+ {file = "protobuf-6.33.6-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:e2afbae9b8e1825e3529f88d514754e094278bb95eadc0e199751cdd9a2e82a2"},
+ {file = "protobuf-6.33.6-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:c96c37eec15086b79762ed265d59ab204dabc53056e3443e702d2681f4b39ce3"},
+ {file = "protobuf-6.33.6-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:e9db7e292e0ab79dd108d7f1a94fe31601ce1ee3f7b79e0692043423020b0593"},
+ {file = "protobuf-6.33.6-cp39-cp39-win32.whl", hash = "sha256:bd56799fb262994b2c2faa1799693c95cc2e22c62f56fb43af311cae45d26f0e"},
+ {file = "protobuf-6.33.6-cp39-cp39-win_amd64.whl", hash = "sha256:f443a394af5ed23672bc6c486be138628fbe5c651ccbc536873d7da23d1868cf"},
+ {file = "protobuf-6.33.6-py3-none-any.whl", hash = "sha256:77179e006c476e69bf8e8ce866640091ec42e1beb80b213c3900006ecfba6901"},
+ {file = "protobuf-6.33.6.tar.gz", hash = "sha256:a6768d25248312c297558af96a9f9c929e8c4cee0659cb07e780731095f38135"},
+]
+
[[package]]
name = "psutil"
version = "7.2.2"
@@ -5480,10 +5723,17 @@ description = "Tensors and Dynamic neural networks in Python with strong GPU acc
optional = true
python-versions = ">=3.10"
files = [
- {file = "torch-2.10.0-1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:c37fc46eedd9175f9c81814cc47308f1b42cfe4987e532d4b423d23852f2bf63"},
- {file = "torch-2.10.0-1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:f699f31a236a677b3118bc0a3ef3d89c0c29b5ec0b20f4c4bf0b110378487464"},
- {file = "torch-2.10.0-1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6abb224c2b6e9e27b592a1c0015c33a504b00a0e0938f1499f7f514e9b7bfb5c"},
- {file = "torch-2.10.0-1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7350f6652dfd761f11f9ecb590bfe95b573e2961f7a242eccb3c8e78348d26fe"},
+ {file = "torch-2.10.0-2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2b980edd8d7c0a68c4e951ee1856334a43193f98730d97408fbd148c1a933313"},
+ {file = "torch-2.10.0-2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:418997cb02d0a0f1497cf6a09f63166f9f5df9f3e16c8a716ab76a72127c714f"},
+ {file = "torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574"},
+ {file = "torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e"},
+ {file = "torch-2.10.0-3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a1ff626b884f8c4e897c4c33782bdacdff842a165fee79817b1dd549fdda1321"},
+ {file = "torch-2.10.0-3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ac5bdcbb074384c66fa160c15b1ead77839e3fe7ed117d667249afce0acabfac"},
+ {file = "torch-2.10.0-3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:98c01b8bb5e3240426dcde1446eed6f40c778091c8544767ef1168fc663a05a6"},
+ {file = "torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b"},
+ {file = "torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49"},
+ {file = "torch-2.10.0-3-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:b1d5e2aba4eb7f8e87fbe04f86442887f9167a35f092afe4c237dfcaaef6e328"},
+ {file = "torch-2.10.0-3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:0228d20b06701c05a8f978357f657817a4a63984b0c90745def81c18aedfa591"},
{file = "torch-2.10.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:5276fa790a666ee8becaffff8acb711922252521b28fbce5db7db5cf9cb2026d"},
{file = "torch-2.10.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:aaf663927bcd490ae971469a624c322202a2a1e68936eb952535ca4cd3b90444"},
{file = "torch-2.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:a4be6a2a190b32ff5c8002a0977a25ea60e64f7ba46b1be37093c141d9c49aeb"},
@@ -5972,9 +6222,10 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more_it
type = ["pytest-mypy"]
[extras]
-all = ["catboost", "cupy-cuda12x", "cupy-cuda12x", "ipywidgets", "nbformat", "nmslib", "nmslib-metabrainz", "plotly", "pytorch-lightning", "pytorch-lightning", "rectools-lightfm", "torch", "torch", "torch"]
+all = ["catboost", "cupy-cuda12x", "cupy-cuda12x", "ipywidgets", "k-means-constrained", "nbformat", "nmslib", "nmslib-metabrainz", "plotly", "pytorch-lightning", "pytorch-lightning", "rectools-lightfm", "torch", "torch", "torch"]
catboost = ["catboost"]
cupy = ["cupy-cuda12x", "cupy-cuda12x"]
+k-means-constrained = ["k-means-constrained"]
lightfm = ["rectools-lightfm"]
nmslib = ["nmslib", "nmslib-metabrainz"]
torch = ["pytorch-lightning", "pytorch-lightning", "torch", "torch", "torch"]
@@ -5983,4 +6234,4 @@ visuals = ["ipywidgets", "nbformat", "plotly"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9, <3.14"
-content-hash = "16e669449fa37e7bbf6ec621fb42bad756aa55b77d503b26839e3d388b30fd0a"
+content-hash = "91bfbe4a6a0aeb25c5d7e90c999546db562f39b17501d74cd022a8475b0192c4"
diff --git a/pyproject.toml b/pyproject.toml
index 8423c7ef..042b85a2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -117,6 +117,7 @@ cupy-cuda12x = [
# and cupy-cuda12x, so we add the version restriction here manually to avoid
# installing older version of fastrlock which is incompatible with Python 3.13
fastrlock = {version = "^0.8.3", markers = "sys_platform != 'darwin'", optional = true}
+k-means-constrained = {version = "^0.7.0", optional = true}
[tool.poetry.extras]
lightfm = ["rectools-lightfm"]
@@ -125,6 +126,7 @@ torch = ["torch", "pytorch-lightning"]
visuals = ["ipywidgets", "plotly", "nbformat"]
cupy = ["cupy-cuda12x"]
catboost = ["catboost"]
+k-means-constrained = ["k-means-constrained"]
all = [
"rectools-lightfm",
"nmslib", "nmslib-metabrainz",
@@ -132,6 +134,7 @@ all = [
"catboost",
"ipywidgets", "plotly", "nbformat",
"cupy-cuda12x",
+ "k-means-constrained",
]
diff --git a/rectools/semantic/__init__.py b/rectools/semantic/__init__.py
new file mode 100644
index 00000000..4ac0f2ae
--- /dev/null
+++ b/rectools/semantic/__init__.py
@@ -0,0 +1,9 @@
+"""rectools.semantic -- generative recommenders with Semantic ID inputs"""
+
+from .config_options import LRScheduleType, OptimizerType, QuantizerType
+
+__all__ = [
+ "LRScheduleType",
+ "OptimizerType",
+ "QuantizerType",
+]
diff --git a/rectools/semantic/config_options.py b/rectools/semantic/config_options.py
new file mode 100644
index 00000000..77c64f23
--- /dev/null
+++ b/rectools/semantic/config_options.py
@@ -0,0 +1,5 @@
+import typing as tp
+
+OptimizerType = tp.Literal["adam", "adagrad", "adamw"]
+QuantizerType = tp.Literal["rqvae", "rkmeans"]
+LRScheduleType = tp.Literal["constant", "cosine", "linear"]
diff --git a/rectools/semantic/data_handling/__init__.py b/rectools/semantic/data_handling/__init__.py
new file mode 100644
index 00000000..4d6adc80
--- /dev/null
+++ b/rectools/semantic/data_handling/__init__.py
@@ -0,0 +1,12 @@
+"""Vectorized sequence preprocessing for TIGER generative recommender"""
+
+from .dataset import PaddingCollateFn, TIGERDataset
+from .k_core import k_core
+from .loo_split import loo_split
+
+__all__ = [
+ "PaddingCollateFn",
+ "TIGERDataset",
+ "k_core",
+ "loo_split",
+]
diff --git a/rectools/semantic/data_handling/dataset.py b/rectools/semantic/data_handling/dataset.py
new file mode 100644
index 00000000..ae96fbb0
--- /dev/null
+++ b/rectools/semantic/data_handling/dataset.py
@@ -0,0 +1,194 @@
+import typing as tp
+
+import numpy as np
+import pandas as pd
+import torch
+from torch.nn.utils.rnn import pad_sequence
+from torch.utils.data import Dataset as TorchDataset
+from tqdm.auto import tqdm
+
+from rectools.semantic.tokenizer import SIDTokenizer
+
+
+class TIGERDataset(TorchDataset): # pylint: disable=too-many-instance-attributes
+ """Dataset for the TIGER generative retrieval model.
+
+ Accepts a pandas DataFrame of interactions, builds per-user item
+ sequences, and tokenizes them into Semantic IDs.
+
+ In train mode (``eval_mode=False``), each sample is a
+ (history, next-item) pair:
+ - ``input_ids`` : 1-D flattened encoder token sequence
+ - ``dec_input`` : decoder input with BOS prepended
+ - ``labels`` : raw SID codes of the target item
+
+ In eval mode (``eval_mode=True``), each sample is:
+ - ``input_ids`` : 1-D flattened encoder token sequence (all items except the last)
+ - ``labels`` : raw item ID of the target (last item)
+
+ Parameters
+ ----------
+ interactions : pd.DataFrame
+ Interaction history with at least ``user_col`` and ``item_col`` columns.
+ tokenizer : SIDTokenizer
+ Trained tokenizer for converting item IDs to Semantic IDs.
+ codebook_sizes : list of int
+ Number of codes per RQ-VAE level.
+ max_length : int
+ Maximum number of history items per sequence.
+ codeword_offset : int
+ Offset added to codeword indices to reserve special tokens.
+ bos_token_id : int
+ Token ID used as beginning-of-sequence in the decoder.
+ only_last : bool
+ If True, only the last ``max_length + 1`` items per user are used.
+ If False, one sample is generated per target position, using up to ``max_length`` preceding items as history.
+ eval_mode : bool
+ If True, the dataset produces evaluation samples (raw item ID as label).
+ max_users : int or None
+ If set, subsample to at most this many users.
+ user_col : str
+ Name of the user ID column.
+ item_col : str
+ Name of the item ID column.
+ timestamp_col : str
+ Name of the timestamp column. If present in the DataFrame,
+ interactions are sorted by (user_col, timestamp_col).
+ """
+
+ def __init__(
+ self,
+ interactions: pd.DataFrame,
+ tokenizer: SIDTokenizer,
+ codebook_sizes: tp.List[int],
+ max_length: int = 20,
+ codeword_offset: int = 2,
+ bos_token_id: int = 1,
+ only_last: bool = True,
+ eval_mode: bool = False,
+ max_users: tp.Optional[int] = None,
+ user_col: str = "user_id",
+ item_col: str = "item_id",
+ timestamp_col: str = "timestamp",
+ ) -> None:
+ self.tokenizer = tokenizer
+ self.codebook_sizes = codebook_sizes
+ self.sid_len = len(codebook_sizes)
+ self.max_length = max_length
+ self.codeword_offset = codeword_offset
+ self.bos_token_id = bos_token_id
+ self.only_last = only_last
+ self.eval_mode = eval_mode
+ self.max_users = max_users
+ self.user_col = user_col
+ self.item_col = item_col
+ self.timestamp_col = timestamp_col
+
+ self.interactions = interactions
+
+ self.offsets = np.cumsum([codeword_offset] + codebook_sizes)[:-1]
+
+ df = interactions
+ if timestamp_col in df.columns:
+ df = df.sort_values([user_col, timestamp_col])
+
+ user_sequences = df.groupby(user_col)[item_col].agg(list)
+
+ self._create_sequences(user_sequences)
+
+ def _create_sequences(self, user_sequences: pd.DataFrame) -> None: # pylint: disable=too-many-branches
+ user_ids = user_sequences.index.tolist()
+ if self.max_users is not None and len(user_ids) > self.max_users:
+ user_ids = np.random.choice(user_ids, size=self.max_users, replace=False).tolist()
+ user_sequences = user_sequences.loc[user_ids]
+
+ if self.eval_mode:
+ self.sequences: tp.List[tp.List[tp.Tuple[int, ...]]] = []
+ self.targets: tp.List[int] = []
+ for seq in user_sequences:
+ if len(seq) > self.max_length + 1:
+ tokenized = self.tokenizer.tokenize(seq[-self.max_length - 1 : -1])
+ else:
+ tokenized = self.tokenizer.tokenize(seq[:-1])
+ assert isinstance(tokenized, list)
+ self.sequences.append(tokenized)
+ self.targets.append(seq[-1])
+ elif self.only_last:
+ self.sequences = []
+ for seq in user_sequences:
+ seq = seq[-self.max_length - 1 :] if len(seq) > self.max_length + 1 else seq
+ tokenized = self.tokenizer.tokenize(seq)
+ assert isinstance(tokenized, list)
+ self.sequences.append(tokenized)
+ else:
+ self.sequences = []
+ for seq in tqdm(user_sequences, desc="Preparing the train data", ncols=120):
+ if len(seq) < 2:
+ continue
+ for t in range(1, len(seq)):
+ start = max(0, t - self.max_length)
+ tokenized = self.tokenizer.tokenize(seq[start : t + 1])
+ assert isinstance(tokenized, list)
+ self.sequences.append(tokenized)
+
+ def __len__(self) -> int:
+ return len(self.sequences)
+
+ def _sid_to_tokens(self, sid: tp.Tuple[int, ...]) -> tp.List[int]:
+ return [code + self.offsets[d] for d, code in enumerate(sid)]
+
+ def __getitem__(self, idx: int) -> tp.Dict[str, tp.Any]:
+ if self.eval_mode:
+ history_sids: tp.List[tp.Tuple[int, ...]] = self.sequences[idx]
+ target = self.targets[idx]
+
+ input_ids = np.array(
+ [tok for sid in history_sids for tok in self._sid_to_tokens(sid)],
+ dtype=np.int64,
+ )
+ return {"input_ids": input_ids, "labels": target}
+
+ item_sequence: tp.List[tp.Tuple[int, ...]] = self.sequences[idx]
+ history_sids = item_sequence[:-1]
+ target_sid: tp.Tuple[int, ...] = item_sequence[-1]
+
+ input_ids = np.array(
+ [tok for sid in history_sids for tok in self._sid_to_tokens(sid)],
+ dtype=np.int64,
+ )
+
+ target_tokens = np.array(self._sid_to_tokens(target_sid), dtype=np.int64)
+
+ dec_input = np.concatenate([[self.bos_token_id], target_tokens[:-1]]).astype(np.int64)
+
+ labels = np.array(target_sid, dtype=np.int64)
+
+ return {"input_ids": input_ids, "dec_input": dec_input, "labels": labels}
+
+
+class PaddingCollateFn:
+ """Automatically right pad user interaction sequences and labels with specified padding values.
+
+ Parameters
+ ----------
+ padding_value : int, optional
+ Value to pad input sequences with, by default 0
+ labels_padding_value : int, optional
+ Value to pad labels with, by default -100
+ """
+
+ def __init__(self, padding_value: int = 0, labels_padding_value: int = -100) -> None:
+ self.padding_value = padding_value
+ self.labels_padding_value = labels_padding_value
+
+ def __call__(self, batch: tp.List[tp.Dict[str, tp.Any]]) -> tp.Dict[str, torch.Tensor]:
+ """Apply padding to all fields in a batch."""
+ collated = {}
+ for key in batch[0]:
+ if np.isscalar(batch[0][key]):
+ collated[key] = torch.tensor([ex[key] for ex in batch])
+ continue
+ pad_val = self.labels_padding_value if key == "labels" else self.padding_value
+ values = [torch.tensor(ex[key]) for ex in batch]
+ collated[key] = pad_sequence(values, batch_first=True, padding_value=pad_val)
+ return collated
diff --git a/rectools/semantic/data_handling/k_core.py b/rectools/semantic/data_handling/k_core.py
new file mode 100644
index 00000000..15401f4c
--- /dev/null
+++ b/rectools/semantic/data_handling/k_core.py
@@ -0,0 +1,48 @@
+import pandas as pd
+
+
+def k_core(
+ interactions: pd.DataFrame,
+ user_min_interactions: int = 5,
+ item_min_interactions: int = 5,
+ user_col: str = "user_id",
+ item_col: str = "item_id",
+) -> pd.DataFrame:
+ """Filter user and item interactions using iterative k-core algorithm.
+
+ Removes users with fewer than ``user_min_interactions`` interactions
+ and items with fewer than ``item_min_interactions`` interactions,
+ repeating until no more removals are needed.
+
+ Parameters
+ ----------
+ interactions : pd.DataFrame
+ Interaction history. Must contain ``user_col`` and ``item_col`` columns.
+ user_min_interactions : int
+ Minimum number of interactions per user.
+ item_min_interactions : int
+ Minimum number of interactions per item.
+ user_col : str
+ Name of the user ID column.
+ item_col : str
+ Name of the item ID column.
+
+ Returns
+ -------
+ pd.DataFrame
+ Filtered interactions.
+ """
+ df = interactions
+ while True:
+ user_counts = df[user_col].value_counts()
+ item_counts = df[item_col].value_counts()
+ bad_users = user_counts[user_counts < user_min_interactions].index
+ bad_items = item_counts[item_counts < item_min_interactions].index
+
+ if len(bad_users) == 0 and len(bad_items) == 0:
+ break
+
+ df = df[~df[user_col].isin(bad_users) & ~df[item_col].isin(bad_items)]
+ df = df.reset_index(drop=True)
+
+ return df
diff --git a/rectools/semantic/data_handling/loo_split.py b/rectools/semantic/data_handling/loo_split.py
new file mode 100644
index 00000000..891107ed
--- /dev/null
+++ b/rectools/semantic/data_handling/loo_split.py
@@ -0,0 +1,47 @@
+import typing as tp
+
+import pandas as pd
+
+
+def loo_split(
+ interactions: pd.DataFrame,
+ user_col: str = "user_id",
+ timestamp_col: str = "timestamp",
+) -> tp.Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+ """Leave-one-out split on an interactions DataFrame.
+
+ The returned frames are nested per-user prefixes: train drops the last two
+ interactions, validation drops only the last one, and test keeps them all.
+ Users with fewer than 3 interactions are dropped.
+
+ Parameters
+ ----------
+ interactions : pd.DataFrame
+ Interaction history.
+ user_col : str
+ Name of the user ID column.
+ timestamp_col : str
+ Name of the timestamp column. If present, interactions are
+ sorted by (user_col, timestamp_col) before splitting.
+
+ Returns
+ -------
+ train_df, val_df, test_df : pd.DataFrame
+ Train contains all but the last 2 interactions per user.
+ Val contains all but the last interaction per user.
+ Test contains all interactions.
+ """
+ df = interactions
+ if timestamp_col in df.columns:
+ df = df.sort_values([user_col, timestamp_col])
+
+ grouped = df.groupby(user_col)
+ valid_users = grouped.filter(lambda x: len(x) >= 3)
+ grouped = valid_users.groupby(user_col)
+
+ cumcounts = grouped.cumcount(ascending=False)
+ train_df = valid_users[cumcounts >= 2].reset_index(drop=True)
+ val_df = valid_users[cumcounts >= 1].reset_index(drop=True)
+ test_df = valid_users.reset_index(drop=True)
+
+ return train_df, val_df, test_df
diff --git a/rectools/semantic/metrics.py b/rectools/semantic/metrics.py
new file mode 100644
index 00000000..7a85fa3f
--- /dev/null
+++ b/rectools/semantic/metrics.py
@@ -0,0 +1,61 @@
+import typing as tp
+from collections import Counter
+
+
+def coverage_k(all_topk_items: tp.List[tp.List[int]], k: int, num_items: int) -> float:
+ """Calculate fraction of catalog items that appear in at least one user's top-K.
+
+ Parameters
+ ----------
+ all_topk_items : tp.List[tp.List[int]]
+ List of all users' top-Ks
+ k : int
+ Users' top-K size
+ num_items : int
+ Number of unique items in dataset
+
+ Returns
+ -------
+ float
+ Coverage@k.
+ """
+ recommended = set()
+ for items in all_topk_items:
+ for item in items[:k]:
+ recommended.add(item)
+ return len(recommended) / num_items
+
+
+def gini_k(all_topk_items: tp.List[tp.List[int]], k: int) -> float:
+ """Gini coefficient over item recommendation frequencies in top-K.
+
+ 0 = every recommended item appears equally often; larger values mean more
+ skewed frequencies. Items that are never recommended are not counted.
+
+ Parameters
+ ----------
+ all_topk_items : tp.List[tp.List[int]]
+ List of all users' top-Ks
+ k : int
+ Users' top-K size
+
+ Returns
+ -------
+ float
+ Gini@k.
+ """
+ counts: Counter[int] = Counter()
+ for items in all_topk_items:
+ for item in items[:k]:
+ counts[item] += 1
+
+ if len(counts) == 0:
+ return 0.0
+
+ freqs = sorted(counts.values())
+ n = len(freqs)
+ cumulative = 0.0
+ for i, freq in enumerate(freqs):
+ cumulative += (2 * (i + 1) - n - 1) * freq
+ total = sum(freqs)
+ return cumulative / (n * total)
diff --git a/rectools/semantic/modules/__init__.py b/rectools/semantic/modules/__init__.py
new file mode 100644
index 00000000..046923c6
--- /dev/null
+++ b/rectools/semantic/modules/__init__.py
@@ -0,0 +1,12 @@
+"""torch.nn modules used in TIGER model and RQ-VAE"""
+
+from .mlp import MLP
+from .transformer_blocks import T5DecoderLayer, T5EncoderLayer, T5RMSNorm, T5RelativePositionBias
+
+__all__ = [
+ "MLP",
+ "T5DecoderLayer",
+ "T5EncoderLayer",
+ "T5RMSNorm",
+ "T5RelativePositionBias",
+]
diff --git a/rectools/semantic/modules/mlp.py b/rectools/semantic/modules/mlp.py
new file mode 100644
index 00000000..de7a624f
--- /dev/null
+++ b/rectools/semantic/modules/mlp.py
@@ -0,0 +1,66 @@
+import typing as tp
+
+from torch import Tensor, nn
+
+
+class MLP(nn.Module):
+ """An implementation of multilayer perceptron.
+
+ Parameters
+ ----------
+ input_dim : int
+ Input dimension.
+ hidden_dims : tp.List[int]
+ Dimensions of hidden layers.
+ out_dim : int
+ Output dimension.
+ dropout : float, optional
+ Dropout probability, by default 0.0
+ normalize : bool, optional
+ Whether to apply batch normalization after every linear layer (including the output one), by default False
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ hidden_dims: tp.List[int],
+ out_dim: int,
+ dropout: float = 0.0,
+ normalize: bool = False,
+ ):
+ super().__init__()
+
+ self.input_dim = input_dim
+ self.hidden_dims = hidden_dims
+ self.out_dim = out_dim
+ self.dropout = dropout
+ self.normalize = normalize
+
+ dims = [self.input_dim] + self.hidden_dims + [self.out_dim]
+
+ self.mlp = nn.Sequential()
+ for i, (in_d, out_d) in enumerate(zip(dims[:-1], dims[1:])):
+ self.mlp.append(nn.Linear(in_d, out_d, bias=False))
+
+ if self.normalize:
+ self.mlp.append(nn.BatchNorm1d(num_features=out_d))
+
+ if i != len(dims) - 2:
+ self.mlp.append(nn.ReLU())
+
+ if dropout != 0:
+ self.mlp.append(nn.Dropout(dropout))
+
+ self.apply(self._init_weights)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Run a forward pass."""
+ assert x.shape[-1] == self.input_dim, f"Invalid input dim: Expected {self.input_dim}, found {x.shape[-1]}"
+ return self.mlp(x)
+
+ # Initialize linear weights with Xavier normal distribution (as the paper said); zero any biases
+ def _init_weights(self, module: nn.Module) -> None:
+ if isinstance(module, nn.Linear):
+ nn.init.xavier_normal_(module.weight.data)
+ if module.bias is not None:
+ module.bias.data.fill_(0.0)
diff --git a/rectools/semantic/modules/transformer_blocks.py b/rectools/semantic/modules/transformer_blocks.py
new file mode 100644
index 00000000..280f4884
--- /dev/null
+++ b/rectools/semantic/modules/transformer_blocks.py
@@ -0,0 +1,609 @@
+import math
+import typing as tp
+
+import torch
+from torch import nn
+from torch.nn.functional import scaled_dot_product_attention
+
+
+class PointWiseFeedForward(nn.Module):
+    """
+    Position-wise feed-forward network that adds nonlinearity to the transformer.
+    It follows the variant used by the SASRec authors.
+
+    Parameters
+    ----------
+    n_factors : int
+        Size of the latent embeddings.
+    n_factors_ff : int
+        Number of hidden units in the network.
+    dropout_rate : float
+        Probability with which a hidden unit is zeroed.
+    activation: torch.nn.Module
+        Module applied as the activation function.
+    bias: bool, default ``True``
+        Whether the linear layers have a bias term.
+    """
+
+    def __init__(
+        self,
+        n_factors: int,
+        n_factors_ff: int,
+        dropout_rate: float,
+        activation: torch.nn.Module,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.ff_linear_1 = nn.Linear(in_features=n_factors, out_features=n_factors_ff, bias=bias)
+        self.ff_dropout_1 = nn.Dropout(p=dropout_rate)
+        self.ff_activation = activation
+        self.ff_linear_2 = nn.Linear(in_features=n_factors_ff, out_features=n_factors, bias=bias)
+
+    def forward(self, seqs: torch.Tensor) -> torch.Tensor:
+        """
+        Apply the network to a batch of sequences.
+
+        Parameters
+        ----------
+        seqs : torch.Tensor
+            User sequences of item embeddings.
+
+        Returns
+        -------
+        torch.Tensor
+            Sequences after the linear -> activation -> dropout -> linear stack.
+        """
+        hidden = self.ff_linear_1(seqs)
+        activated = self.ff_activation(hidden)
+        return self.ff_linear_2(self.ff_dropout_1(activated))
+
+
+class SwigluFeedForward(nn.Module):
+    """
+    Gated (SwiGLU) feed-forward network that adds nonlinearity to the transformer model.
+    The implementation is based on FuXi and LLama SwigLU https://arxiv.org/pdf/2502.03036,
+    LiGR https://arxiv.org/pdf/2502.03417
+
+    Parameters
+    ----------
+    n_factors : int
+        Size of the latent embeddings.
+    n_factors_ff : int
+        Number of hidden units in the network.
+    dropout_rate : float
+        Probability with which a hidden unit is zeroed.
+    bias: bool, default ``True``
+        Whether the linear layers have a bias term.
+    """
+
+    def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float, bias: bool = True) -> None:
+        super().__init__()
+        self.ff_linear_1 = nn.Linear(in_features=n_factors, out_features=n_factors_ff, bias=bias)
+        self.ff_dropout_1 = nn.Dropout(p=dropout_rate)
+        self.ff_activation = nn.SiLU()
+        self.ff_linear_2 = nn.Linear(in_features=n_factors_ff, out_features=n_factors, bias=bias)
+        self.ff_linear_3 = nn.Linear(in_features=n_factors, out_features=n_factors_ff, bias=bias)
+
+    def forward(self, seqs: torch.Tensor) -> torch.Tensor:
+        """
+        Apply the network to a batch of sequences.
+
+        Parameters
+        ----------
+        seqs : torch.Tensor
+            User sequences of item embeddings.
+
+        Returns
+        -------
+        torch.Tensor
+            Sequences after the gated projection, dropout and output projection.
+        """
+        gate = self.ff_activation(self.ff_linear_1(seqs))
+        gated = gate * self.ff_linear_3(seqs)
+        return self.ff_linear_2(self.ff_dropout_1(gated))
+
+
+def init_feed_forward(
+    n_factors: int,
+    ff_factors_multiplier: int,
+    dropout_rate: float,
+    ff_activation: str,
+    bias: bool = True,
+) -> nn.Module:
+    """
+    Build a feed-forward network for one of the activations "swiglu", "relu" or "gelu".
+
+    Parameters
+    ----------
+    n_factors : int
+        Size of the latent embeddings.
+    ff_factors_multiplier : int
+        The hidden layer has ``n_factors * ff_factors_multiplier`` units.
+    dropout_rate : float
+        Probability with which a hidden unit is zeroed.
+    ff_activation : {"swiglu", "relu", "gelu"}
+        Activation function to use.
+    bias: bool, default ``True``
+        Whether the linear layers have a bias term.
+
+    Returns
+    -------
+    nn.Module
+        Feed-forward network.
+    """
+    # The hidden size is the same for every variant.
+    n_factors_ff = n_factors * ff_factors_multiplier
+    if ff_activation == "swiglu":
+        return SwigluFeedForward(n_factors, n_factors_ff, dropout_rate, bias=bias)
+    # "gelu" and "relu" share the point-wise network and differ only in the activation.
+    activation: nn.Module
+    if ff_activation == "gelu":
+        activation = torch.nn.GELU()
+    elif ff_activation == "relu":
+        activation = torch.nn.ReLU()
+    else:
+        raise ValueError(f"Unsupported ff_activation: {ff_activation}")
+    return PointWiseFeedForward(
+        n_factors,
+        n_factors_ff,
+        dropout_rate,
+        activation=activation,
+        bias=bias,
+    )
+
+
+class T5RMSNorm(nn.Module):
+    """
+    Root-mean-square layer normalisation in the style of T5.
+
+    In contrast to ``nn.LayerNorm`` the mean is **not** subtracted and there is
+    **no bias**: inputs are divided by their RMS and multiplied by a learned scale.
+
+    Parameters
+    ----------
+    d_model : int
+        Feature dimension.
+    eps : float
+        Epsilon for numerical stability.
+    """
+
+    def __init__(self, d_model: int, eps: float = 1e-6) -> None:
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(d_model))
+        self.eps = eps
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Run a forward pass."""
+        x_fp32 = x.float()
+        inv_rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+        return (x_fp32 * inv_rms).to(x.dtype) * self.weight
+
+
+class T5RelativePositionBias(nn.Module):
+    """
+    Relative-position bias table with log-spaced distance buckets, following
+    the T5 paper (Raffel et al., 2019).
+
+    The returned bias is meant to be added to the attention logits; a single
+    instance can be shared by all layers of the encoder (or decoder).
+
+    Parameters
+    ----------
+    num_heads : int
+        Number of attention heads (one scalar bias per head).
+    bidirectional : bool
+        ``True`` for the encoder (positions attend in both directions),
+        ``False`` for the decoder (causal -- only attends to earlier
+        positions).
+    num_buckets : int
+        Number of distance buckets.
+    max_distance : int
+        Distances beyond this value are clamped to the last bucket.
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        bidirectional: bool = True,
+        num_buckets: int = 32,
+        max_distance: int = 128,
+    ) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        self.bidirectional = bidirectional
+        self.num_buckets = num_buckets
+        self.max_distance = max_distance
+
+        self.relative_attention_bias = nn.Embedding(num_embeddings=num_buckets, embedding_dim=num_heads)
+
+    @staticmethod
+    def _relative_position_bucket(
+        relative_position: torch.Tensor,
+        bidirectional: bool,
+        num_buckets: int,
+        max_distance: int,
+    ) -> torch.Tensor:
+        """Translate signed relative positions into bucket indices."""
+        bucket = torch.zeros_like(relative_position)
+
+        if bidirectional:
+            num_buckets = num_buckets // 2
+            bucket = bucket + (relative_position > 0).to(torch.long) * num_buckets
+            distance = torch.abs(relative_position)
+        else:
+            # Causal case: future positions collapse to distance 0 (no peeking ahead).
+            distance = torch.clamp(-relative_position, min=0)
+
+        # The first half of the buckets is used for exact small distances.
+        max_exact = num_buckets // 2
+        is_exact = distance < max_exact
+
+        # The other half covers larger distances on a logarithmic scale.
+        # For distance 0 the log is -inf; torch.where below picks the exact
+        # bucket for such small distances instead.
+        log_ratio = torch.log(distance.float() / max_exact) / math.log(max_distance / max_exact)
+        log_bucket = max_exact + (log_ratio * (num_buckets - max_exact)).long()
+        log_bucket = log_bucket.clamp(max=num_buckets - 1)
+
+        bucket = bucket + torch.where(is_exact, distance, log_bucket)
+        return bucket
+
+    def forward(self, query_len: int, key_len: int, device: torch.device) -> torch.Tensor:
+        """
+        Compute the position bias for a query/key length pair.
+
+        Returns
+        -------
+        Tensor [1, num_heads, query_len, key_len]
+            Additive bias to be added to the attention logits.
+        """
+        query_idx = torch.arange(query_len, dtype=torch.long, device=device)
+        key_idx = torch.arange(key_len, dtype=torch.long, device=device)
+        # offsets[i, j] = j - i (key position minus query position)
+        offsets = key_idx[None, :] - query_idx[:, None]
+
+        bucket_ids = self._relative_position_bucket(
+            offsets,
+            bidirectional=self.bidirectional,
+            num_buckets=self.num_buckets,
+            max_distance=self.max_distance,
+        )
+        # Embedding lookup gives [query_len, key_len, num_heads]
+        per_head_bias = self.relative_attention_bias(bucket_ids)
+        # -> [1, num_heads, query_len, key_len]
+        return per_head_bias.permute(2, 0, 1)[None]
+
+
+class T5Attention(nn.Module):
+    """
+    Multi-head attention with separate key/value dimension, as in T5.
+
+    Unlike ``nn.MultiheadAttention`` where the per-head dimension is
+    always ``embed_dim // num_heads``, this module lets you choose
+    ``d_kv`` freely. Projections have no bias (following T5).
+
+    Uses ``scaled_dot_product_attention`` for the core computation,
+    which dispatches to FlashAttention / memory-efficient backends
+    automatically.
+
+    Parameters
+    ----------
+    d_model : int
+        Model (query input) dimension.
+    num_heads : int
+        Number of attention heads.
+    d_kv : int
+        Per-head dimension for keys and values (and queries).
+    dropout : float
+        Attention dropout probability (applied only during training).
+    kv_input_dim : int, optional
+        Input dimension for key/value tensors. Defaults to ``d_model``.
+        Set to a different value for cross-attention when the encoder
+        output dimension differs from the decoder's ``d_model``.
+    """
+
+    # nn.TransformerEncoder/Decoder access self_attn.batch_first
+    batch_first: bool = True
+
+    def __init__(
+        self,
+        d_model: int,
+        num_heads: int,
+        d_kv: int,
+        dropout: float = 0.0,
+        kv_input_dim: tp.Optional[int] = None,
+    ) -> None:
+        super().__init__()
+        if kv_input_dim is None:
+            kv_input_dim = d_model
+
+        self.num_heads = num_heads
+        self.d_kv = d_kv
+        self.inner_dim = num_heads * d_kv
+
+        self.q_proj = nn.Linear(d_model, self.inner_dim, bias=False)
+        self.k_proj = nn.Linear(kv_input_dim, self.inner_dim, bias=False)
+        self.v_proj = nn.Linear(kv_input_dim, self.inner_dim, bias=False)
+        self.o_proj = nn.Linear(self.inner_dim, d_model, bias=False)
+
+        self.dropout = dropout
+
+    def forward(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        attn_mask: tp.Optional[torch.Tensor] = None,
+        key_padding_mask: tp.Optional[torch.Tensor] = None,
+        need_weights: bool = False,
+        is_causal: bool = False,
+        position_bias: tp.Optional[torch.Tensor] = None,
+    ) -> tp.Tuple[torch.Tensor, None]:
+        """
+        Parameters
+        ----------
+        query : Tensor [B, T_q, d_model]
+        key : Tensor [B, T_kv, kv_input_dim]
+        value : Tensor [B, T_kv, kv_input_dim]
+        attn_mask : Tensor, optional
+            Additive float mask broadcastable to ``(B, num_heads, T_q, T_kv)``.
+            ``-inf`` blocks attention.
+        key_padding_mask : BoolTensor [B, T_kv], optional
+            ``True`` at positions that should be masked (padding).
+        need_weights : bool
+            Ignored (always returns ``None`` for weights).
+        is_causal : bool
+            If ``True``, apply a causal mask; it is merged into the explicit mask when one is present.
+        position_bias : Tensor, optional
+            Additive relative-position bias broadcastable to
+            ``(B, num_heads, T_q, T_kv)``. Produced by
+            ``T5RelativePositionBias``.
+
+        Returns
+        -------
+        (output, None)
+            output : Tensor [B, T_q, d_model]
+        """
+        B, T_q, _ = query.shape
+        T_kv = key.shape[1]
+
+        # Project and reshape to (B, num_heads, T, d_kv)
+        q = self.q_proj(query).view(B, T_q, self.num_heads, self.d_kv).transpose(1, 2)
+        k = self.k_proj(key).view(B, T_kv, self.num_heads, self.d_kv).transpose(1, 2)
+        v = self.v_proj(value).view(B, T_kv, self.num_heads, self.d_kv).transpose(1, 2)
+
+        # Convert key_padding_mask [B, T_kv] to an additive mask [B, 1, 1, T_kv].
+        # The mask may arrive as bool (True=pad) or as a float tensor
+        # (0/-inf) -- nn.TransformerEncoder runs _canonical_mask which
+        # converts bool to float before calling the layer.
+        pad_mask: tp.Optional[torch.Tensor] = None
+        if key_padding_mask is not None:
+            if key_padding_mask.dtype == torch.bool:
+                pad_mask = torch.zeros(key_padding_mask.shape, dtype=q.dtype, device=q.device)
+                pad_mask = pad_mask.masked_fill_(key_padding_mask, float("-inf"))
+            else:
+                pad_mask = key_padding_mask.to(dtype=q.dtype)
+            pad_mask = pad_mask.unsqueeze(1).unsqueeze(1)
+
+        # SDPA does not allow is_causal=True together with an explicit mask.
+        # Rather than silently dropping causality in that case, express it as
+        # an additive mask (-inf wherever key index > query index) and merge it.
+        causal_mask: tp.Optional[torch.Tensor] = None
+        if is_causal and not (attn_mask is None and pad_mask is None and position_bias is None):
+            future = torch.ones(T_q, T_kv, dtype=torch.bool, device=q.device).triu(diagonal=1)
+            causal_mask = torch.zeros(T_q, T_kv, dtype=q.dtype, device=q.device)
+            causal_mask = causal_mask.masked_fill_(future, float("-inf"))
+            is_causal = False
+
+        # Sum all additive terms in a fixed order (mask, padding, position
+        # bias, causal); broadcasting yields (B, num_heads, T_q, T_kv).
+        effective_mask: tp.Optional[torch.Tensor] = None
+        for term in (attn_mask, pad_mask, position_bias, causal_mask):
+            if term is None:
+                continue
+            effective_mask = term if effective_mask is None else effective_mask + term
+
+        out = scaled_dot_product_attention(  # pylint: disable=not-callable
+            q,
+            k,
+            v,
+            attn_mask=effective_mask,
+            dropout_p=self.dropout if self.training else 0.0,
+            is_causal=is_causal,
+        )  # [B, num_heads, T_q, d_kv]
+
+        out = out.transpose(1, 2).contiguous().view(B, T_q, self.inner_dim)
+        out = self.o_proj(out)
+        return out, None
+
+
+class T5EncoderLayer(nn.Module):
+    """
+    Transformer encoder layer in pre-norm form with T5-style attention.
+
+    Architecture per step::
+
+        x_norm = RMSNorm(x)
+        x = x + Dropout(SelfAttention(x_norm, x_norm, x_norm))
+        x_norm = RMSNorm(x)
+        x = x + Dropout(FFN(x_norm))
+
+    Parameters
+    ----------
+    d_model : int
+        Model dimension.
+    num_heads : int
+        Number of attention heads.
+    d_kv : int
+        Per-head key/value (and query) dimension.
+    dim_feedforward : int
+        Inner dimension of the feed-forward network.
+    dropout : float
+        Dropout probability.
+    activation : str
+        FFN activation (``"relu"`` or ``"gelu"``).
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        num_heads: int,
+        d_kv: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        self.self_attn = T5Attention(
+            d_model=d_model,
+            num_heads=num_heads,
+            d_kv=d_kv,
+            dropout=dropout,
+        )
+        self.norm1 = T5RMSNorm(d_model)
+        self.norm2 = T5RMSNorm(d_model)
+        self.dropout1 = nn.Dropout(p=dropout)
+        self.dropout2 = nn.Dropout(p=dropout)
+
+        self.linear1 = nn.Linear(in_features=d_model, out_features=dim_feedforward)
+        self.linear2 = nn.Linear(in_features=dim_feedforward, out_features=d_model)
+
+        activation_classes: tp.Dict[str, tp.Type[nn.Module]] = {
+            "relu": nn.ReLU,
+            "gelu": nn.GELU,
+        }
+        if activation not in activation_classes:
+            raise ValueError(f"Unsupported activation: {activation}")
+        self.activation: nn.Module = activation_classes[activation]()
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        src_key_padding_mask: tp.Optional[torch.Tensor] = None,
+        position_bias: tp.Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Run a forward pass."""
+        # Self-attention sub-layer (pre-norm, residual).
+        normed = self.norm1(src)
+        attn_out, _ = self.self_attn(
+            query=normed,
+            key=normed,
+            value=normed,
+            key_padding_mask=src_key_padding_mask,
+            position_bias=position_bias,
+        )
+        hidden = src + self.dropout1(attn_out)
+
+        # Feed-forward sub-layer (pre-norm, residual).
+        ff_out = self.linear2(self.activation(self.linear1(self.norm2(hidden))))
+        hidden = hidden + self.dropout2(ff_out)
+        return hidden
+
+
+class T5DecoderLayer(nn.Module):
+    """
+    Transformer decoder layer in pre-norm form with T5-style attention.
+
+    Architecture per step::
+
+        x_norm = RMSNorm(x)
+        x = x + Dropout(CausalSelfAttention(x_norm, x_norm, x_norm))
+        x_norm = RMSNorm(x)
+        x = x + Dropout(CrossAttention(x_norm, memory, memory))
+        x_norm = RMSNorm(x)
+        x = x + Dropout(FFN(x_norm))
+
+    Parameters
+    ----------
+    d_model : int
+        Model dimension.
+    num_heads : int
+        Number of attention heads.
+    d_kv : int
+        Per-head key/value (and query) dimension.
+    dim_feedforward : int
+        Inner dimension of the feed-forward network.
+    dropout : float
+        Dropout probability.
+    activation : str
+        FFN activation (``"relu"`` or ``"gelu"``).
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        num_heads: int,
+        d_kv: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+        self.self_attn = T5Attention(
+            d_model=d_model,
+            num_heads=num_heads,
+            d_kv=d_kv,
+            dropout=dropout,
+        )
+        self.cross_attn = T5Attention(
+            d_model=d_model,
+            num_heads=num_heads,
+            d_kv=d_kv,
+            dropout=dropout,
+            kv_input_dim=d_model,
+        )
+        self.norm1 = T5RMSNorm(d_model)
+        self.norm2 = T5RMSNorm(d_model)
+        self.norm3 = T5RMSNorm(d_model)
+        self.dropout1 = nn.Dropout(p=dropout)
+        self.dropout2 = nn.Dropout(p=dropout)
+        self.dropout3 = nn.Dropout(p=dropout)
+
+        self.linear1 = nn.Linear(in_features=d_model, out_features=dim_feedforward)
+        self.linear2 = nn.Linear(in_features=dim_feedforward, out_features=d_model)
+
+        activation_classes: tp.Dict[str, tp.Type[nn.Module]] = {
+            "relu": nn.ReLU,
+            "gelu": nn.GELU,
+        }
+        if activation not in activation_classes:
+            raise ValueError(f"Unsupported activation: {activation}")
+        self.activation: nn.Module = activation_classes[activation]()
+
+    def forward(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        tgt_mask: tp.Optional[torch.Tensor] = None,
+        memory_key_padding_mask: tp.Optional[torch.Tensor] = None,
+        position_bias: tp.Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Run a forward pass."""
+        # Self-attention. Causality is not enforced here: it has to be encoded
+        # in ``tgt_mask`` by the caller.
+        normed = self.norm1(tgt)
+        sa_out, _ = self.self_attn(
+            query=normed,
+            key=normed,
+            value=normed,
+            attn_mask=tgt_mask,
+            position_bias=position_bias,
+        )
+        hidden = tgt + self.dropout1(sa_out)
+
+        # Cross-attention (no position bias per T5 design)
+        normed = self.norm2(hidden)
+        ca_out, _ = self.cross_attn(
+            query=normed,
+            key=memory,
+            value=memory,
+            key_padding_mask=memory_key_padding_mask,
+        )
+        hidden = hidden + self.dropout2(ca_out)
+
+        # Feed-forward
+        normed = self.norm3(hidden)
+        inner = self.activation(self.linear1(normed))
+        ff_out = self.linear2(inner)
+        hidden = hidden + self.dropout3(ff_out)
+        return hidden
diff --git a/rectools/semantic/tiger/__init__.py b/rectools/semantic/tiger/__init__.py
new file mode 100644
index 00000000..a962e05c
--- /dev/null
+++ b/rectools/semantic/tiger/__init__.py
@@ -0,0 +1,11 @@
+"""TIGER generative recommender model with semantic ID inputs"""
+
+from .lightning import TIGERLightning
+from .model import TIGERModel
+from .module import TIGERNet
+
+__all__ = [
+ "TIGERLightning",
+ "TIGERModel",
+ "TIGERNet",
+]
diff --git a/rectools/semantic/tiger/lightning.py b/rectools/semantic/tiger/lightning.py
new file mode 100644
index 00000000..4aa57e7d
--- /dev/null
+++ b/rectools/semantic/tiger/lightning.py
@@ -0,0 +1,254 @@
+import typing as tp
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
+
+from rectools.semantic import LRScheduleType, OptimizerType
+from rectools.semantic.metrics import coverage_k, gini_k
+from rectools.semantic.tokenizer import SIDTokenizer
+
+from .loss import compute_tiger_loss
+from .module import TIGERNet
+
+
+class TIGERLightning(pl.LightningModule):
+    """A ``pytorch_lightning`` wrapper around TIGERNet which is used for
+    training, validation and testing of the model.
+
+    Parameters
+    ----------
+    model : TIGERNet
+        Model to train
+    tokenizer : SIDTokenizer
+        Semantic ID tokenizer
+    optimizer_name : OptimizerType, optional
+        Which optimizer to use, by default "adamw"
+    beam_size : int, optional
+        Beam size for model inference, by default 20
+    top_k : int, optional
+        Used for @k metrics like HR@k or MRR@k, by default 10
+    lr : float, optional
+        Model learning rate, by default 5e-4
+    weight_decay : float, optional
+        Optimizer weight decay, by default 1e-4
+    warmup_steps : int | float, optional
+        Learning rate warm-up steps, by default 300. A float strictly between
+        0 and 1 is treated as a fraction of ``max_iters``.
+    lr_schedule : tp.Optional[LRScheduleType], optional
+        Learning rate schedule to use. Constant LR if set to None, by default None
+    max_iters : int, optional
+        Maximum number of iterations for training, by default 1
+    num_items : tp.Optional[int], optional
+        Number of unique items, by default None
+    decode_rng : tp.Optional[np.random.RandomState], optional
+        Random state passed to ``tokenizer.decode`` in validation and test
+        steps, by default None
+    """
+
+    def __init__(
+        self,
+        model: TIGERNet,
+        tokenizer: SIDTokenizer,
+        optimizer_name: OptimizerType = "adamw",
+        beam_size: int = 20,
+        top_k: int = 10,
+        lr: float = 5e-4,
+        weight_decay: float = 1e-4,
+        warmup_steps: tp.Union[int, float] = 300,
+        lr_schedule: tp.Optional[LRScheduleType] = None,
+        max_iters: int = 1,
+        num_items: tp.Optional[int] = None,
+        decode_rng: tp.Optional[np.random.RandomState] = None,
+    ):
+        super().__init__()
+
+        self.model = model
+        self.tokenizer = tokenizer
+        self.beam_size = beam_size
+        self.top_k = top_k
+        self.num_items = num_items
+
+        self.optimizer_name = optimizer_name
+        self.lr = lr
+        self.lr_schedule = lr_schedule
+        self.weight_decay = weight_decay
+        self.max_iters = max_iters
+        self.decode_rng = decode_rng
+
+        # Deduplicated recommendation lists gathered during one val/test epoch.
+        self._all_topk_items: tp.List[tp.List[int]] = []
+
+        if isinstance(warmup_steps, float) and 0 < warmup_steps < 1:
+            self.warmup_steps = round(max_iters * warmup_steps)
+        else:
+            self.warmup_steps = int(warmup_steps)
+
+    def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT:
+        """Compute and log the training loss for one batch."""
+        input_ids = batch["input_ids"]
+        dec_input = batch["dec_input"]
+        labels = batch["labels"]
+
+        enc_padding_mask = input_ids == TIGERNet.PAD_TOKEN_ID
+
+        logits = self.model(input_ids, dec_input, enc_padding_mask)
+        loss = compute_tiger_loss(logits, labels)
+
+        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
+        return loss
+
+    def _valtest_step(  # pylint: disable=too-many-locals
+        self, batch: tp.Dict[str, torch.Tensor], batch_idx: int, mode: tp.Literal["val", "test"]
+    ) -> STEP_OUTPUT:
+        """Generate beams, decode them to items and log Hit/NDCG/MRR@k for one batch."""
+        prefix = "" if mode == "test" else "val_"
+        input_sids = batch["input_ids"]
+        labels = batch["labels"].cpu().tolist()
+
+        metrics = {}
+        all_topk_items = []
+
+        enc_padding_mask = input_sids == TIGERNet.PAD_TOKEN_ID
+
+        codes, _scores = self.model.generate(input_sids, enc_padding_mask=enc_padding_mask, beam_size=self.beam_size)
+        batch_size, num_beams = codes.size(0), codes.size(1)
+
+        # Batch-decode: flatten (batch, beam, depth) -> list of tuples, decode once
+        codes_cpu = codes.cpu().tolist()
+        all_sids = [tuple(codes_cpu[i][j]) for i in range(batch_size) for j in range(num_beams)]
+        all_decoded: tp.List[tp.Optional[int]] = self.tokenizer.decode(  # type: ignore[assignment]
+            all_sids,
+            rng=self.decode_rng,
+        )
+
+        # Reshape decoded items back to (batch, beam)
+        decoded_grid = [all_decoded[i * num_beams : (i + 1) * num_beams] for i in range(batch_size)]
+
+        # Per-row deduplication + rank finding (small Python loop)
+        ranks = np.zeros(batch_size, dtype=np.int64)
+        for i, row in enumerate(decoded_grid):
+            target = labels[i]
+            seen = set()
+            unique_items = []
+            for item_id in row:
+                # Skip SIDs that decoded to no item and items already ranked.
+                if item_id is None or item_id in seen:
+                    continue
+                seen.add(item_id)
+                unique_items.append(item_id)
+                # 1-based rank within the deduplicated list; 0 stays "not found".
+                if item_id == target:
+                    ranks[i] = len(unique_items)
+            all_topk_items.append(unique_items)
+
+        self._all_topk_items.extend(all_topk_items)
+
+        # Vectorized metric computation
+        in_top_k = (ranks > 0) & (ranks <= self.top_k)
+        top_k_ranks = ranks[in_top_k]
+
+        hit_sum = in_top_k.sum()
+
+        metrics[f"{prefix}Hit@{self.top_k}"] = hit_sum / batch_size
+        metrics[f"{prefix}NDCG@{self.top_k}"] = (
+            float((1.0 / np.log2(top_k_ranks + 1)).sum()) if hit_sum > 0 else 0.0
+        ) / batch_size
+        metrics[f"{prefix}MRR@{self.top_k}"] = (float((1.0 / top_k_ranks).sum()) if hit_sum > 0 else 0.0) / batch_size
+        self.log_dict(metrics, on_epoch=True, on_step=False, prog_bar=True)
+        return metrics
+
+    def _on_valtest_epoch_end(self, mode: tp.Literal["val", "test"]) -> STEP_OUTPUT:
+        """Log Gini@k (and Coverage@k when ``num_items`` is set) over the epoch's recommendations."""
+        prefix = "" if mode == "test" else "val_"
+        metrics = {}
+        metrics[f"{prefix}Gini@{self.top_k}"] = gini_k(self._all_topk_items, self.top_k)
+        if self.num_items is not None:
+            metrics[f"{prefix}Coverage@{self.top_k}"] = coverage_k(self._all_topk_items, self.top_k, self.num_items)
+        self.log_dict(metrics, on_epoch=True, on_step=False, prog_bar=True)
+        return metrics
+
+    def on_validation_epoch_start(self) -> None:
+        """Reset the per-epoch recommendation buffer."""
+        self._all_topk_items = []
+
+    def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT:
+        """Run one validation step."""
+        return self._valtest_step(batch, batch_idx, mode="val")
+
+    def on_validation_epoch_end(self) -> None:
+        """Log epoch-level validation metrics."""
+        self._on_valtest_epoch_end(mode="val")
+
+    def on_test_epoch_start(self) -> None:
+        """Reset the per-epoch recommendation buffer."""
+        self._all_topk_items = []
+
+    def test_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT:
+        """Run one test step."""
+        return self._valtest_step(batch, batch_idx, mode="test")
+
+    def on_test_epoch_end(self) -> None:
+        """Log epoch-level test metrics."""
+        self._on_valtest_epoch_end(mode="test")
+
+    def configure_optimizers(self) -> OptimizerLRScheduler:
+        """Build the optimizer and, when ``lr_schedule`` is set, a step-wise LR scheduler."""
+        optimizer: torch.optim.Optimizer
+        if self.optimizer_name == "adamw":
+            optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
+        elif self.optimizer_name == "adagrad":
+            optimizer = torch.optim.Adagrad(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
+        elif self.optimizer_name == "adam":
+            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
+        else:
+            raise ValueError(f"Unknown optimizer: {self.optimizer_name}")
+
+        if self.lr_schedule is None:
+            return optimizer
+
+        schedulers: tp.List[torch.optim.lr_scheduler.LRScheduler] = []
+        milestones: tp.List[int] = []
+        if self.warmup_steps > 0:
+            schedulers.append(
+                torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=self.warmup_steps)
+            )
+            milestones = [self.warmup_steps]
+        # Length of the post-warm-up phase in optimizer steps.
+        main_iters = max(1, self.max_iters)
+        if self.lr_schedule == "cosine":
+            schedulers.append(
+                torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
+                    optimizer,
+                    T_0=main_iters,
+                )
+            )
+        elif self.lr_schedule == "linear":
+            # LinearLR defaults to total_iters=5; without an explicit horizon the
+            # learning rate would reach zero after only five steps.
+            schedulers.append(
+                torch.optim.lr_scheduler.LinearLR(
+                    optimizer,
+                    start_factor=1.0,
+                    end_factor=0.0,
+                    total_iters=main_iters,
+                )
+            )
+        else:
+            schedulers.append(
+                torch.optim.lr_scheduler.ConstantLR(
+                    optimizer,
+                    factor=1.0,
+                )
+            )
+        scheduler = torch.optim.lr_scheduler.SequentialLR(
+            optimizer,
+            schedulers,
+            milestones=milestones,
+        )
+
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
+        }
diff --git a/rectools/semantic/tiger/loss.py b/rectools/semantic/tiger/loss.py
new file mode 100644
index 00000000..13664612
--- /dev/null
+++ b/rectools/semantic/tiger/loss.py
@@ -0,0 +1,36 @@
+import typing as tp
+
+import torch
+from torch.nn.functional import cross_entropy
+
+
+def compute_tiger_loss(logits: tp.List[torch.Tensor], labels: torch.Tensor) -> torch.Tensor:
+    """
+    Compute cross-entropy loss for TIGER single-target prediction.
+
+    Parameters
+    ----------
+    logits : list of Tensor, each [B, codebook_size_d]
+        One tensor per codebook level.
+    labels : LongTensor [B, sid_len]
+        Raw SID codes (0-based, no offset) for the target item.
+        Entries equal to ``-100`` are ignored.
+
+    Returns
+    -------
+    Tensor
+        Scalar loss: the mean over codebook levels of the per-level
+        cross-entropy.
+    """
+    sid_len = labels.shape[-1]
+    # Keep the accumulator on the same device as the data instead of the CPU.
+    total_loss = torch.zeros((), device=labels.device)
+
+    for d in range(sid_len):
+        total_loss = total_loss + cross_entropy(
+            logits[d],  # [B, codebook_size_d]
+            labels[:, d],  # [B]
+            ignore_index=-100,
+        )
+
+    return total_loss / sid_len
diff --git a/rectools/semantic/tiger/model.py b/rectools/semantic/tiger/model.py
new file mode 100644
index 00000000..585518d4
--- /dev/null
+++ b/rectools/semantic/tiger/model.py
@@ -0,0 +1,396 @@
+import json
+import os
+import typing as tp
+
+import numpy as np
+import pandas as pd
+import pytorch_lightning as pl
+import torch
+from pytorch_lightning.callbacks import LearningRateMonitor, ModelSummary
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from torch.utils.data import DataLoader
+from tqdm.auto import trange
+
+from rectools.semantic import LRScheduleType, OptimizerType
+from rectools.semantic.data_handling import PaddingCollateFn, TIGERDataset
+from rectools.semantic.tokenizer import SIDTokenizer
+
+from .lightning import TIGERLightning
+from .module import TIGERNet
+
+
+class TIGERModel: # pylint: disable=too-many-instance-attributes
+ """TIGER generative recommender model with semantic IDs.
+
+ Given a pretrained SIDTokenizer, trains a TIGER generative recommender model.
+
+ Parameters
+ ----------
+ tokenizer : SIDTokenizer
+ Pretrained semantic ID tokenizer.
+ """
+
+ # pylint: disable=too-many-arguments,too-many-locals
+    def __init__(
+        self,
+        tokenizer: SIDTokenizer,
+        # Architecture of the TIGER network
+        hidden_units: int = 256,
+        num_blocks: int = 2,
+        num_heads: int = 1,
+        dropout_rate: float = 0.1,
+        max_length: int = 200,
+        ff_dim: tp.Optional[int] = None,
+        d_kv: tp.Optional[int] = None,
+        # Optimisation, data loading and evaluation settings
+        optimizer: OptimizerType = "adamw",
+        lr: float = 1e-3,
+        lr_schedule: tp.Optional[LRScheduleType] = None,
+        warmup_steps: int = 300,
+        weight_decay: float = 1e-4,
+        max_epochs: int = 100,
+        patience: tp.Optional[int] = 10,
+        batch_size: int = 256,
+        eval_batch_size: int = 64,
+        num_workers: int = 4,
+        beam_size: int = 20,
+        top_k: int = 10,
+        train_only_last: bool = True,
+        random_seed: tp.Optional[int] = None,
+        # Device selection
+        device: tp.Optional[str] = None,
+    ):
+        self.tokenizer = tokenizer
+        self.codebook_sizes = list(tokenizer.quantizer.codebook_sizes)
+
+        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+        # Architecture of the TIGER network
+        self.hidden_units = hidden_units
+        self.num_blocks = num_blocks
+        self.num_heads = num_heads
+        self.dropout_rate = dropout_rate
+        self.max_length = max_length
+        self.ff_dim = ff_dim
+        self.d_kv = d_kv
+
+        # Optimisation, data loading and evaluation settings
+        self.optimizer = optimizer
+        self.lr = lr
+        self.lr_schedule = lr_schedule
+        self.warmup_steps = warmup_steps
+        self.weight_decay = weight_decay
+        self.max_epochs = max_epochs
+        self.patience = patience
+        self.batch_size = batch_size
+        self.eval_batch_size = eval_batch_size
+        self.num_workers = num_workers
+        self.beam_size = beam_size
+        self.top_k = top_k
+        self.train_only_last = train_only_last
+        self.random_seed = random_seed
+        self._decode_rng = None if random_seed is None else np.random.RandomState(random_seed)
+
+        self.model = TIGERNet(
+            codebook_sizes=self.codebook_sizes,
+            hidden_units=self.hidden_units,
+            num_blocks=self.num_blocks,
+            num_heads=self.num_heads,
+            dropout_rate=self.dropout_rate,
+            max_length=self.max_length,
+            ff_dim=self.ff_dim,
+            d_kv=self.d_kv,
+        ).to(self.device)
+
+    def _tiger_dataset_kwargs(self) -> dict:
+        """Return the keyword arguments shared by every ``TIGERDataset`` the model builds."""
+        # Settings taken from the model instance.
+        shared = {"tokenizer": self.tokenizer, "codebook_sizes": self.codebook_sizes, "max_length": self.max_length}
+        # Token-id conventions defined by the network class.
+        shared["codeword_offset"] = TIGERNet.CODEWORD_OFFSET
+        shared["bos_token_id"] = TIGERNet.BOS_TOKEN_ID
+        return shared
+
+    def fit(
+        self,
+        train_df: pd.DataFrame,
+        val_df: pd.DataFrame,
+        monitor: tp.Literal["Hit", "MRR", "NDCG"] = "Hit",
+    ) -> None:
+        """Train the TIGER model.
+
+        Parameters
+        ----------
+        train_df : pd.DataFrame
+            Training interactions (user_id, item_id, optionally timestamp).
+        val_df : pd.DataFrame
+            Validation interactions (same schema).
+        monitor : {"Hit", "MRR", "NDCG"}
+            Validation metric (at ``top_k``) watched by early stopping, by default "Hit".
+
+        Notes
+        -----
+        If multiple items share the same Semantic ID, decoding during
+        validation remains stochastic. Set ``random_seed`` at model
+        initialization to make this collision resolution reproducible.
+        """
+        num_items = len(set(train_df["item_id"]).union(val_df["item_id"]))
+
+        ds_kwargs = self._tiger_dataset_kwargs()
+
+        train_dataset = TIGERDataset(
+            interactions=train_df,
+            only_last=self.train_only_last,
+            eval_mode=False,
+            **ds_kwargs,
+        )
+
+        val_dataset = TIGERDataset(
+            interactions=val_df,
+            only_last=True,
+            eval_mode=True,
+            **ds_kwargs,
+        )
+
+        train_loader = DataLoader(
+            train_dataset,
+            batch_size=self.batch_size,
+            shuffle=True,
+            num_workers=self.num_workers,
+            collate_fn=PaddingCollateFn(padding_value=TIGERNet.PAD_TOKEN_ID),
+            pin_memory=True,
+        )
+
+        val_loader = DataLoader(
+            val_dataset,
+            batch_size=self.eval_batch_size,
+            num_workers=self.num_workers,
+            collate_fn=PaddingCollateFn(padding_value=TIGERNet.PAD_TOKEN_ID),
+            pin_memory=True,
+        )
+
+        lightning_model = TIGERLightning(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            optimizer_name=self.optimizer,
+            lr=self.lr,
+            lr_schedule=self.lr_schedule,
+            warmup_steps=self.warmup_steps,
+            max_iters=self.max_epochs * len(train_loader) - self.warmup_steps,
+            weight_decay=self.weight_decay,
+            num_items=num_items,
+            beam_size=self.beam_size,
+            top_k=self.top_k,
+            # Reuse the seeded RNG so that SID collision resolution in validation is reproducible.
+            decode_rng=self._decode_rng,
+        )
+
+        callbacks = [LearningRateMonitor(logging_interval="step"), ModelSummary()]
+        if self.patience is not None:
+            callbacks.append(
+                EarlyStopping(
+                    monitor=f"val_{monitor}@{self.top_k}",
+                    patience=self.patience,
+                    mode="max",
+                )
+            )
+
+        trainer = pl.Trainer(max_epochs=self.max_epochs, callbacks=callbacks)
+        trainer.fit(lightning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)
+
+        self.model = lightning_model.model.to(self.device)
+
+    @torch.no_grad()
+    def predict(
+        self,
+        interactions: pd.DataFrame,
+        top_k: int = 10,
+    ) -> pd.DataFrame:
+        """Recommend up to ``top_k`` items for every user found in ``interactions``.
+
+        Each user's history is truncated to its last ``max_length`` items,
+        tokenized into Semantic IDs, and fed to beam search with
+        ``beam_size=top_k``. Beams that decode to ``None`` or to an item already
+        emitted for the same user are skipped, so a user may receive fewer than
+        ``top_k`` rows.
+
+        Parameters
+        ----------
+        interactions : pd.DataFrame
+            User interaction histories (user_id, item_id, optionally timestamp).
+            When a ``timestamp`` column exists, rows are sorted by it per user.
+        top_k : int
+            Number of recommendations per user.
+
+        Returns
+        -------
+        pd.DataFrame
+            Recommendations with columns: user_id, item_id, score, rank.
+
+        Notes
+        -----
+        If multiple items share the same Semantic ID, recommendation
+        deduplication depends on stochastic SID decoding. Set
+        ``random_seed`` at model initialization to make outputs
+        reproducible.
+        """
+        self.model.eval()
+
+        history = interactions
+        if "timestamp" in history.columns:
+            history = history.sort_values(["user_id", "timestamp"])
+
+        per_user = history.groupby("user_id")["item_id"].agg(list)
+        user_ids = per_user.index.tolist()
+
+        # Vocabulary offset of every codebook level (the special tokens come first).
+        offsets = np.cumsum([TIGERNet.CODEWORD_OFFSET] + self.codebook_sizes)[:-1]
+
+        enc_tokens_list = []
+        for item_seq in per_user.tolist():
+            recent_items = item_seq[-self.max_length :]
+            sids: tp.List[tp.Tuple[int]] = self.tokenizer.tokenize(recent_items)  # type: ignore[assignment]
+            flat_tokens: tp.List[int] = [code + offsets[level] for sid in sids for level, code in enumerate(sid)]
+            enc_tokens_list.append(torch.tensor(flat_tokens, dtype=torch.long))
+
+        rows: tp.List[tp.Tuple] = []
+        n_users = len(enc_tokens_list)
+
+        for start in trange(0, n_users, self.eval_batch_size, desc="Generating predictions"):
+            stop = min(start + self.eval_batch_size, n_users)
+            batch_tokens = enc_tokens_list[start:stop]
+            batch_user_ids = user_ids[start:stop]
+
+            # Right-pad every sequence of the batch to the longest one.
+            longest = max(seq_tokens.size(0) for seq_tokens in batch_tokens)
+            enc_input = torch.full(
+                (len(batch_tokens), longest),
+                TIGERNet.PAD_TOKEN_ID,
+                dtype=torch.long,
+                device=self.device,
+            )
+            for row_idx, seq_tokens in enumerate(batch_tokens):
+                enc_input[row_idx, : seq_tokens.size(0)] = seq_tokens.to(self.device)
+
+            codes, scores = self.model.generate(
+                enc_input,
+                enc_padding_mask=enc_input == TIGERNet.PAD_TOKEN_ID,
+                beam_size=top_k,
+            )
+
+            num_beams = codes.size(1)
+            flat_sids = [tuple(beam) for user_beams in codes.cpu().tolist() for beam in user_beams]
+            scores_cpu = scores.cpu()
+
+            # One decode call for the whole batch; results are laid out user-major.
+            decoded: tp.List[tp.Optional[int]] = self.tokenizer.decode(  # type: ignore[assignment]
+                flat_sids,
+                rng=self._decode_rng,
+            )
+
+            for row_idx in range(codes.size(0)):
+                uid = batch_user_ids[row_idx]
+                user_items = decoded[row_idx * num_beams : (row_idx + 1) * num_beams]
+                user_scores = scores_cpu[row_idx]
+                emitted: tp.Set[int] = set()
+                for beam_idx, item_id in enumerate(user_items):
+                    if item_id is None or item_id in emitted:
+                        continue
+                    emitted.add(item_id)
+                    # The rank is the 1-based position among the items kept so far.
+                    rows.append((uid, item_id, user_scores[beam_idx].item(), len(emitted)))
+                    if len(emitted) >= top_k:
+                        break
+
+        return pd.DataFrame(rows, columns=["user_id", "item_id", "score", "rank"])
+
+    def evaluate(
+        self,
+        test_df: pd.DataFrame,
+        top_k: tp.Optional[int] = None,
+    ) -> tp.Mapping[str, float]:
+        """Compute test metrics by running a Lightning ``test`` loop.
+
+        Parameters
+        ----------
+        test_df : pd.DataFrame
+            Test interactions (user_id, item_id, optionally timestamp).
+            For each user, the last item is treated as the ground-truth target.
+        top_k : int, optional
+            Number of recommendations to generate. Defaults to ``self.top_k``.
+
+        Returns
+        -------
+        dict
+            Metric name -> value (Hit@k, NDCG@k, MRR@k, etc.).
+
+        Notes
+        -----
+        If multiple items share the same Semantic ID, decoding during
+        evaluation remains stochastic. Set ``random_seed`` at model
+        initialization to make metric computation reproducible.
+        """
+        k = self.top_k if top_k is None else top_k
+
+        test_loader = DataLoader(
+            TIGERDataset(
+                interactions=test_df,
+                eval_mode=True,
+                only_last=True,
+                **self._tiger_dataset_kwargs(),
+            ),
+            batch_size=self.eval_batch_size,
+            num_workers=self.num_workers,
+            collate_fn=PaddingCollateFn(padding_value=TIGERNet.PAD_TOKEN_ID),
+            pin_memory=True,
+        )
+
+        lightning_model = TIGERLightning(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            # The item count is taken from the distinct item ids present in the test frame.
+            num_items=test_df["item_id"].nunique(),
+            beam_size=self.beam_size,
+            top_k=k,
+            decode_rng=self._decode_rng,
+        )
+
+        # ``Trainer.test`` returns one metrics dict per dataloader; there is exactly one here.
+        results = pl.Trainer().test(lightning_model, dataloaders=test_loader)[0]
+
+        # Presumably the trainer may relocate the network; put it back on the configured device.
+        self.model = self.model.to(self.device)
+        return results
+
+    def save(self, directory: str) -> None:
+        """Persist the tokenizer, the network weights and the architecture config.
+
+        Creates ``directory`` if needed and writes ``tokenizer.pt``,
+        ``model.pt`` (the network ``state_dict``) and ``config.json``.
+
+        Parameters
+        ----------
+        directory : str
+            Target directory.
+        """
+        os.makedirs(directory, exist_ok=True)
+
+        tokenizer_path = os.path.join(directory, "tokenizer.pt")
+        weights_path = os.path.join(directory, "model.pt")
+        config_path = os.path.join(directory, "config.json")
+
+        self.tokenizer.save(tokenizer_path)
+        torch.save(self.model.state_dict(), weights_path)
+
+        # Only these attributes are stored in the config file.
+        config_fields = (
+            "hidden_units",
+            "num_blocks",
+            "num_heads",
+            "dropout_rate",
+            "max_length",
+            "ff_dim",
+            "d_kv",
+            "random_seed",
+        )
+        config = {field: getattr(self, field) for field in config_fields}
+        with open(config_path, "w", encoding="utf-8") as f:
+            json.dump(config, f, indent=2)
+
+    @classmethod
+    def load(cls, directory: str, device: tp.Optional[str] = None) -> "TIGERModel":
+        """Load a saved TIGERModel from a directory.
+
+        Reads ``config.json``, ``tokenizer.pt`` and ``model.pt`` from ``directory``.
+
+        Parameters
+        ----------
+        directory : str
+            Directory containing the three files listed above.
+        device : str, optional
+            Target device. Defaults to ``"cuda"`` when available, otherwise ``"cpu"``.
+
+        Returns
+        -------
+        TIGERModel
+            Model rebuilt from the stored config, with the saved network weights loaded.
+        """
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+
+        with open(os.path.join(directory, "config.json"), encoding="utf-8") as f:
+            config = json.load(f)
+
+        tokenizer = SIDTokenizer.load(os.path.join(directory, "tokenizer.pt"), device=device)
+
+        tiger_model = cls(tokenizer=tokenizer, device=device, **config)
+
+        # The checkpoint is a plain ``state_dict`` of tensors; ``weights_only=True`` stops
+        # ``torch.load`` from unpickling arbitrary objects out of a tampered file.
+        state_dict = torch.load(
+            os.path.join(directory, "model.pt"),
+            map_location=device,
+            weights_only=True,
+        )
+        tiger_model.model.load_state_dict(state_dict)
+
+        return tiger_model
diff --git a/rectools/semantic/tiger/module.py b/rectools/semantic/tiger/module.py
new file mode 100644
index 00000000..3efb95b0
--- /dev/null
+++ b/rectools/semantic/tiger/module.py
@@ -0,0 +1,542 @@
+import math
+import typing as tp
+
+import torch
+from torch import nn
+from torch.nn.functional import log_softmax
+
+from rectools.semantic.modules import T5DecoderLayer, T5EncoderLayer, T5RMSNorm, T5RelativePositionBias
+
+
+class T5Encoder(nn.Module):
+    """
+    Stack of ``T5EncoderLayer`` blocks followed by a final ``T5RMSNorm``.
+
+    A single bidirectional ``T5RelativePositionBias`` is evaluated once per
+    forward pass and the same bias is handed to every layer. Used in place
+    of ``nn.TransformerEncoder`` on the T5 code path.
+
+    Parameters
+    ----------
+    d_model : int
+        Model dimension.
+    num_heads : int
+        Number of attention heads.
+    d_kv : int
+        Per-head key/value (and query) dimension.
+    num_layers : int
+        Number of encoder layers.
+    dim_feedforward : int
+        FFN inner dimension.
+    dropout : float
+        Dropout probability.
+    activation : str
+        FFN activation.
+    relative_position_num_buckets : int
+        Number of buckets for relative position bias.
+    relative_position_max_distance : int
+        Max distance for relative position bucketing.
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        num_heads: int,
+        d_kv: int,
+        num_layers: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: str = "relu",
+        relative_position_num_buckets: int = 32,
+        relative_position_max_distance: int = 128,
+    ) -> None:
+        super().__init__()
+        # Every layer is built from the same hyper-parameters.
+        layer_kwargs: tp.Dict[str, tp.Any] = {
+            "d_model": d_model,
+            "num_heads": num_heads,
+            "d_kv": d_kv,
+            "dim_feedforward": dim_feedforward,
+            "dropout": dropout,
+            "activation": activation,
+        }
+        self.layers = nn.ModuleList(T5EncoderLayer(**layer_kwargs) for _ in range(num_layers))
+        self.final_norm = T5RMSNorm(d_model)
+        self.position_bias = T5RelativePositionBias(
+            num_heads=num_heads,
+            bidirectional=True,
+            num_buckets=relative_position_num_buckets,
+            max_distance=relative_position_max_distance,
+        )
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        src_key_padding_mask: tp.Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Run ``src`` through all layers with one shared position bias, then normalise."""
+        # Self-attention: query and key lengths are both the size of dimension 1 of ``src``.
+        seq_len = src.size(1)
+        shared_bias = self.position_bias(query_len=seq_len, key_len=seq_len, device=src.device)
+
+        hidden = src
+        for block in self.layers:
+            hidden = block(hidden, src_key_padding_mask=src_key_padding_mask, position_bias=shared_bias)
+        return self.final_norm(hidden)
+
+
+class T5Decoder(nn.Module):
+    """
+    Stack of ``T5DecoderLayer`` blocks followed by a final ``T5RMSNorm``.
+
+    A single unidirectional (``bidirectional=False``) ``T5RelativePositionBias``
+    is evaluated once per forward pass and the same bias is handed to every
+    layer. Used in place of ``nn.TransformerDecoder`` on the T5 code path.
+
+    Parameters
+    ----------
+    d_model : int
+        Model dimension.
+    num_heads : int
+        Number of attention heads.
+    d_kv : int
+        Per-head key/value (and query) dimension.
+    num_layers : int
+        Number of decoder layers.
+    dim_feedforward : int
+        FFN inner dimension.
+    dropout : float
+        Dropout probability.
+    activation : str
+        FFN activation.
+    relative_position_num_buckets : int
+        Number of buckets for relative position bias.
+    relative_position_max_distance : int
+        Max distance for relative position bucketing.
+    """
+
+    def __init__(
+        self,
+        d_model: int,
+        num_heads: int,
+        d_kv: int,
+        num_layers: int,
+        dim_feedforward: int = 2048,
+        dropout: float = 0.1,
+        activation: str = "relu",
+        relative_position_num_buckets: int = 32,
+        relative_position_max_distance: int = 128,
+    ) -> None:
+        super().__init__()
+        # Every layer is built from the same hyper-parameters.
+        layer_kwargs: tp.Dict[str, tp.Any] = {
+            "d_model": d_model,
+            "num_heads": num_heads,
+            "d_kv": d_kv,
+            "dim_feedforward": dim_feedforward,
+            "dropout": dropout,
+            "activation": activation,
+        }
+        self.layers = nn.ModuleList(T5DecoderLayer(**layer_kwargs) for _ in range(num_layers))
+        self.final_norm = T5RMSNorm(d_model)
+        self.position_bias = T5RelativePositionBias(
+            num_heads=num_heads,
+            bidirectional=False,
+            num_buckets=relative_position_num_buckets,
+            max_distance=relative_position_max_distance,
+        )
+
+    def forward(
+        self,
+        tgt: torch.Tensor,
+        memory: torch.Tensor,
+        tgt_mask: tp.Optional[torch.Tensor] = None,
+        memory_key_padding_mask: tp.Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Run ``tgt`` through all layers against ``memory`` with one shared position bias."""
+        # The bias is computed over target positions only (query and key length are both tgt's).
+        tgt_len = tgt.size(1)
+        shared_bias = self.position_bias(query_len=tgt_len, key_len=tgt_len, device=tgt.device)
+
+        hidden = tgt
+        for block in self.layers:
+            hidden = block(
+                hidden,
+                memory,
+                tgt_mask=tgt_mask,
+                memory_key_padding_mask=memory_key_padding_mask,
+                position_bias=shared_bias,
+            )
+        return self.final_norm(hidden)
+
+
+class TIGERNet(nn.Module): # pylint: disable=too-many-instance-attributes
+ """
+ Generative retrieval model for sequential recommendation.
+
+ Items are represented as Semantic IDs -- tuples of discrete codewords
+ produced by RQ-VAE. The model is a Transformer encoder-decoder:
+ * Encoder input : flattened Semantic ID tokens of the user history.
+ * Decoder output: Semantic ID tokens of the next item, predicted
+ autoregressively one codeword at a time.
+
+ Special tokens
+ --------------
+ 0 -- padding token (in both encoder and decoder)
+ 1 -- BOS (beginning-of-sequence) token fed to the decoder at step 0
+
+ All codeword tokens are offset by +2 so that the vocabulary is:
+ {0: PAD, 1: BOS, 2 .. 2+sum(codebook_sizes)-1 : codewords}
+ Each codebook level gets its own contiguous slice of the vocabulary.
+
+ Parameters
+ ----------
+ codebook_sizes : list[int]
+ Number of codes per RQ-VAE level (e.g. [256, 256, 256]).
+ hidden_units : int
+ Dimensionality of the Transformer hidden states.
+ num_blocks : int
+ Number of Transformer layers in **each** of the encoder and decoder.
+ num_heads : int
+ Number of attention heads.
+ dropout_rate : float
+ Dropout probability.
+ max_length : int
+ Maximum number of items in the encoder history.
+ initializer_range : float
+ Std for weight initialisation.
+ ff_dim : int, optional
+ Feed-forward inner dimension. Defaults to ``4 * hidden_units``.
+ d_kv : int, optional
+ Per-head key/value (and query) projected dimension for T5-style
+ attention. When set, the encoder and decoder use T5-style layers
+ with RMSNorm and relative position bias (no absolute positional
+ embeddings). When ``None`` (default), standard
+ ``nn.TransformerEncoderLayer`` / ``nn.TransformerDecoderLayer``
+ are used with learned absolute positional embeddings.
+ relative_position_num_buckets : int
+ Number of buckets for T5 relative position bias (only used when
+ ``d_kv`` is set).
+ relative_position_max_distance : int
+ Max distance for T5 relative position bucketing (only used when
+ ``d_kv`` is set).
+ """
+
+ PAD_TOKEN_ID = 0
+ BOS_TOKEN_ID = 1
+ CODEWORD_OFFSET = 2 # first codeword token id
+
+    def __init__(
+        self,
+        codebook_sizes: tp.List[int],
+        hidden_units: int = 256,
+        num_blocks: int = 2,
+        num_heads: int = 1,
+        dropout_rate: float = 0.1,
+        max_length: int = 256,
+        initializer_range: float = 0.02,
+        ff_dim: tp.Optional[int] = None,
+        d_kv: tp.Optional[int] = None,
+        relative_position_num_buckets: int = 32,
+        relative_position_max_distance: int = 128,
+    ) -> None:
+        """Create the token embedding, encoder/decoder stacks and one output head per level."""
+        super().__init__()
+
+        self.codebook_sizes = codebook_sizes
+        self.sid_len = len(codebook_sizes)  # number of codewords per item
+        self.hidden_units = hidden_units
+        self.num_blocks = num_blocks
+        self.num_heads = num_heads
+        self.dropout_rate = dropout_rate
+        self.max_length = max_length
+        self.initializer_range = initializer_range
+        self.d_kv = d_kv
+
+        # PAD and BOS take the first ids; all codeword tokens follow them.
+        self.vocab_size = self.CODEWORD_OFFSET + sum(codebook_sizes)
+
+        # The encoder sees ``sid_len`` tokens per history item; the decoder emits one item.
+        self.enc_max_len = max_length * self.sid_len
+        self.dec_max_len = self.sid_len
+
+        inner_dim = hidden_units * 4 if ff_dim is None else ff_dim
+
+        self.token_emb = nn.Embedding(self.vocab_size, hidden_units, padding_idx=self.PAD_TOKEN_ID)
+
+        self.enc_dropout = nn.Dropout(dropout_rate)
+        self.dec_dropout = nn.Dropout(dropout_rate)
+
+        self.encoder: nn.Module
+        self.decoder: nn.Module
+        if d_kv is not None:
+            # T5 path: relative position bias instead of absolute embeddings.
+            self.enc_pos_emb = None
+            self.dec_pos_emb = None
+
+            t5_kwargs: tp.Dict[str, tp.Any] = {
+                "d_model": hidden_units,
+                "num_heads": num_heads,
+                "d_kv": d_kv,
+                "num_layers": num_blocks,
+                "dim_feedforward": inner_dim,
+                "dropout": dropout_rate,
+                "activation": "relu",
+                "relative_position_num_buckets": relative_position_num_buckets,
+                "relative_position_max_distance": relative_position_max_distance,
+            }
+            self.encoder = T5Encoder(**t5_kwargs)
+            self.decoder = T5Decoder(**t5_kwargs)
+        else:
+            # Standard path: learned absolute positional embeddings.
+            self.enc_pos_emb = nn.Embedding(self.enc_max_len, hidden_units)
+            self.dec_pos_emb = nn.Embedding(self.dec_max_len + 1, hidden_units)  # +1 for BOS step
+
+            torch_layer_kwargs: tp.Dict[str, tp.Any] = {
+                "d_model": hidden_units,
+                "nhead": num_heads,
+                "dim_feedforward": inner_dim,
+                "dropout": dropout_rate,
+                "activation": "relu",
+                "batch_first": True,
+                "norm_first": True,
+            }
+            encoder_layer = nn.TransformerEncoderLayer(**torch_layer_kwargs)
+            decoder_layer = nn.TransformerDecoderLayer(**torch_layer_kwargs)
+
+            self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_blocks, norm=nn.LayerNorm(hidden_units))
+            self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_blocks, norm=nn.LayerNorm(hidden_units))
+
+        # One classification head per codebook level, each over that level's codes only.
+        self.output_heads = nn.ModuleList(nn.Linear(hidden_units, cb_size) for cb_size in codebook_sizes)
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module: nn.Module) -> None:
+        """Initialise one sub-module in place; intended for use with ``nn.Module.apply``.
+
+        Linear and embedding weights are drawn from ``N(0, initializer_range)``,
+        biases and the padding embedding row are zeroed, and normalisation
+        scales are reset to one. Other module types are left untouched.
+        """
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+            return
+
+        if isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.initializer_range)
+            pad_row = module.padding_idx
+            if pad_row is not None:
+                # Keep the padding vector at exactly zero after the random draw.
+                module.weight.data[pad_row].zero_()
+            return
+
+        if isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+            return
+
+        if isinstance(module, T5RMSNorm):
+            module.weight.data.fill_(1.0)
+
+    def _codebook_offsets(self, device: torch.device) -> torch.Tensor:
+        """Return a tensor with the first vocabulary id of every codebook level.
+
+        Level 0 starts at ``CODEWORD_OFFSET``; each following level starts
+        where the previous level's codes end.
+        """
+        running = self.CODEWORD_OFFSET
+        offsets = [running]
+        for level_size in self.codebook_sizes[:-1]:
+            running += level_size
+            offsets.append(running)
+        return torch.tensor(offsets, device=device)
+
+    def sid_to_tokens(self, sid: torch.Tensor) -> torch.Tensor:
+        """Add the per-level vocabulary offsets to raw codes (broadcast over the last dimension)."""
+        return sid + self._codebook_offsets(sid.device)
+
+    def tokens_to_sid(self, tokens: torch.Tensor) -> torch.Tensor:
+        """Subtract the per-level vocabulary offsets from token ids to recover raw codes."""
+        return tokens - self._codebook_offsets(tokens.device)
+
+    def encode(
+        self,
+        enc_input: torch.Tensor,
+        enc_padding_mask: tp.Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Embed the history tokens and run the encoder stack.
+
+        Parameters
+        ----------
+        enc_input : LongTensor [B, T_enc]
+            Vocabulary token ids for the encoder (already offset).
+        enc_padding_mask : BoolTensor [B, T_enc], optional
+            True where the token is padding.
+
+        Returns
+        -------
+        memory : Tensor [B, T_enc, H]
+        """
+        embedded = self.token_emb(enc_input)
+        if self.enc_pos_emb is not None:
+            # Absolute positions exist only on the standard (non-T5) path.
+            position_ids = torch.arange(enc_input.size(1), device=enc_input.device).unsqueeze(0)
+            embedded = embedded + self.enc_pos_emb(position_ids)
+        return self.encoder(self.enc_dropout(embedded), src_key_padding_mask=enc_padding_mask)
+
+    def decode(
+        self,
+        dec_input: torch.Tensor,
+        memory: torch.Tensor,
+        enc_padding_mask: tp.Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """
+        Run the decoder stack over ``dec_input`` under a causal mask.
+
+        Parameters
+        ----------
+        dec_input : LongTensor [B, T_dec]
+            Decoder input token ids (BOS followed by target tokens shifted right).
+        memory : Tensor [B, T_enc, H]
+            Encoder output.
+        enc_padding_mask : BoolTensor [B, T_enc], optional
+            Padding mask for cross-attention.
+
+        Returns
+        -------
+        Tensor [B, T_dec, H]
+        """
+        tgt_len = dec_input.size(1)
+        embedded = self.token_emb(dec_input)
+        if self.dec_pos_emb is not None:
+            # NOTE(review): ``encode`` adds positions without this sqrt(H) scaling -- confirm the asymmetry is intended.
+            embedded = embedded * math.sqrt(self.hidden_units)
+            position_ids = torch.arange(tgt_len, device=dec_input.device).unsqueeze(0)
+            embedded = embedded + self.dec_pos_emb(position_ids)
+        embedded = self.dec_dropout(embedded)
+
+        # Position i may only attend to positions <= i.
+        causal_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len, device=dec_input.device)
+
+        return self.decoder(
+            embedded,
+            memory,
+            tgt_mask=causal_mask,
+            memory_key_padding_mask=enc_padding_mask,
+        )
+
+    def forward(
+        self,
+        enc_input: torch.Tensor,
+        dec_input: torch.Tensor,
+        enc_padding_mask: tp.Optional[torch.Tensor] = None,
+    ) -> tp.List[torch.Tensor]:
+        """Return one logits tensor per codebook level.
+
+        Decoder position ``d`` is projected by ``output_heads[d]``, so entry
+        ``d`` of the result has shape ``[B, codebook_sizes[d]]``.
+        """
+        memory = self.encode(enc_input, enc_padding_mask)
+        hidden = self.decode(dec_input, memory, enc_padding_mask)  # [B, sid_len, H]
+        return [self.output_heads[level](hidden[:, level, :]) for level in range(self.sid_len)]
+
+    @torch.no_grad()
+    def generate_greedy(
+        self,
+        enc_input: torch.Tensor,
+        enc_padding_mask: tp.Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Decode one Semantic ID per row by taking the arg-max code at every level.
+
+        Returns
+        -------
+        LongTensor [B, sid_len]
+            Raw (un-offset) codes, one per codebook level.
+        """
+        memory = self.encode(enc_input, enc_padding_mask)
+        device = enc_input.device
+        offsets = self._codebook_offsets(device)
+
+        # The decoder prefix starts as a single BOS column and grows by one token per level.
+        dec_input = torch.full((enc_input.size(0), 1), self.BOS_TOKEN_ID, dtype=torch.long, device=device)
+        chosen_codes = []
+
+        for level in range(self.sid_len):
+            hidden = self.decode(dec_input, memory, enc_padding_mask)  # [B, level+1, H]
+            level_logits = self.output_heads[level](hidden[:, -1, :])  # [B, cb_size]
+            best_code = level_logits.argmax(dim=-1)  # [B]
+            chosen_codes.append(best_code)
+
+            # The next step consumes the vocabulary id of the chosen code, not the raw code.
+            next_token = (best_code + offsets[level]).unsqueeze(1)
+            dec_input = torch.cat([dec_input, next_token], dim=1)
+
+        return torch.stack(chosen_codes, dim=1)  # [B, sid_len]
+
+    @torch.no_grad()
+    def generate(  # pylint: disable=too-many-locals
+        self,
+        enc_input: torch.Tensor,
+        enc_padding_mask: tp.Optional[torch.Tensor] = None,
+        beam_size: int = 10,
+    ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Beam search decoding (fully batched).
+
+        The history is encoded once, replicated ``beam_size`` times per row,
+        and one codebook level is decoded per step. Every step re-runs the
+        decoder over the whole prefix built so far.
+
+        Parameters
+        ----------
+        enc_input : LongTensor [B, T_enc]
+            Encoder token ids.
+        enc_padding_mask : BoolTensor [B, T_enc], optional
+            True where the token is padding.
+        beam_size : int
+            Number of beams kept per row.
+
+        Returns
+        -------
+        all_sids : LongTensor [B, beam_size, sid_len]
+            Top-``beam_size`` predicted Semantic IDs (raw codewords, 0-based).
+        all_scores : Tensor [B, beam_size]
+            Log-probability scores for each beam.
+
+        Notes
+        -----
+        NOTE(review): when a level offers fewer live continuations than
+        ``beam_size`` (e.g. ``beam_size`` larger than the first codebook),
+        the surplus slots appear to be filled from the masked beams and so
+        carry scores around ``-1e9`` -- confirm callers tolerate that.
+        """
+        memory = self.encode(enc_input, enc_padding_mask)
+        B = enc_input.size(0)
+        device = enc_input.device
+        offsets = self._codebook_offsets(device)
+
+        K = beam_size
+        BK = B * K
+
+        # Replicate encoder output (and its mask) K times per row: [B, ...] -> [B*K, ...].
+        memory = memory.unsqueeze(1).expand(-1, K, -1, -1).reshape(BK, memory.size(1), memory.size(2))
+        if enc_padding_mask is not None:
+            enc_padding_mask = enc_padding_mask.unsqueeze(1).expand(-1, K, -1).reshape(BK, -1)
+
+        # All K beams start from the same BOS prefix; masking beams 1..K-1 with a large
+        # negative score keeps the first step from selecting K copies of the same code.
+        scores = torch.zeros(B, K, device=device)
+        scores[:, 1:] = -1e9
+
+        # Codes chosen so far per beam; the last dimension grows by one each level.
+        all_codes = torch.zeros(B, K, 0, dtype=torch.long, device=device)
+
+        dec_tokens = torch.full((BK, 1), self.BOS_TOKEN_ID, dtype=torch.long, device=device)
+
+        for d in range(self.sid_len):
+            hidden = self.decode(dec_tokens, memory, enc_padding_mask)
+            logits_d = self.output_heads[d](hidden[:, -1, :])
+            log_probs = log_softmax(logits_d, dim=-1)
+
+            # Per beam, only its best min(K, cb_size) continuations can reach the global top-K.
+            cb_size = self.codebook_sizes[d]
+            topk_k = min(K, cb_size)
+            topk_vals, topk_idx = log_probs.topk(topk_k, dim=-1)
+
+            topk_vals = topk_vals.view(B, K, topk_k)
+            topk_idx = topk_idx.view(B, K, topk_k)
+
+            # Flatten (beam, continuation) pairs into one candidate axis per row.
+            candidate_scores = scores.unsqueeze(-1) + topk_vals
+            candidate_scores = candidate_scores.view(B, K * topk_k)
+            candidate_idx = topk_idx.view(B, K * topk_k)
+            # Source beam index of every flattened candidate.
+            candidate_beam = (
+                torch.arange(K, device=device).unsqueeze(-1).expand(-1, topk_k).reshape(1, K * topk_k).expand(B, -1)
+            )
+
+            # Keep the K best candidates per row, remembering their source beam and code.
+            best_scores, best_flat = candidate_scores.topk(K, dim=-1)
+            best_beam = candidate_beam.gather(1, best_flat)
+            best_code = candidate_idx.gather(1, best_flat)
+
+            scores = best_scores
+
+            # Re-order the code history to follow the surviving beams, then append the new code.
+            prev_codes = all_codes.gather(1, best_beam.unsqueeze(-1).expand(-1, -1, d))
+            all_codes = torch.cat([prev_codes, best_code.unsqueeze(-1)], dim=-1)
+
+            # Convert per-row beam indices to indices into the flattened [B*K] layout.
+            beam_offset = torch.arange(B, device=device).unsqueeze(1) * K
+            flat_beam_idx = (beam_offset + best_beam).view(BK)
+
+            # Re-order decoder prefixes the same way and append the new vocabulary token.
+            dec_tokens = dec_tokens[flat_beam_idx]
+            new_token = (best_code + offsets[d]).view(BK, 1)
+            dec_tokens = torch.cat([dec_tokens, new_token], dim=1)
+
+        return all_codes, scores
diff --git a/rectools/semantic/tokenizer/__init__.py b/rectools/semantic/tokenizer/__init__.py
new file mode 100644
index 00000000..0de6bd42
--- /dev/null
+++ b/rectools/semantic/tokenizer/__init__.py
@@ -0,0 +1,9 @@
+"""Item ID tokenizer based on RQ-VAE and R-KMeans"""
+
+from .emb_dataset import EmbDataset
+from .model import SIDTokenizer
+
+__all__ = [
+ "EmbDataset",
+ "SIDTokenizer",
+]
diff --git a/rectools/semantic/tokenizer/base_quantizer.py b/rectools/semantic/tokenizer/base_quantizer.py
new file mode 100644
index 00000000..2b2cc10c
--- /dev/null
+++ b/rectools/semantic/tokenizer/base_quantizer.py
@@ -0,0 +1,37 @@
+import typing as tp
+from abc import ABC, abstractmethod
+
+import torch
+from torch import nn
+
+
+class QuantizerOutput(tp.NamedTuple):
+    """Quantizer outputs."""
+
+    # One tuple of codes per input row; ``SIDTokenizer`` indexes it by batch position.
+    sem_ids: tp.List[tp.Tuple[int, ...]]
+    # Loss tensor of the forward pass (presumably scalar -- confirm in the concrete quantizers).
+    loss: torch.Tensor
+
+
+class Quantizer(ABC, nn.Module):
+    """Base class for quantizers like RQ-VAE and RK-Means.
+
+    Stores the shared configuration and an initially empty ``codebooks``
+    module list; subclasses must implement ``init_codebooks`` and ``forward``.
+
+    Parameters
+    ----------
+    input_dim : int
+        Stored as ``self.input_dim``.
+    codebook_sizes : list of int
+        Stored as ``self.codebook_sizes``.
+    """
+
+    def __init__(self, input_dim: int, codebook_sizes: tp.List[int]) -> None:
+        super().__init__()
+        self.input_dim = input_dim
+        self.codebook_sizes = codebook_sizes
+        # Starts empty; presumably populated by the concrete subclass -- confirm there.
+        self.codebooks = nn.ModuleList()
+
+    # NOTE(review): ``torch.no_grad`` here wraps only this abstract stub; an overriding
+    # method does not inherit the decorator and has to apply it itself if it needs it.
+    @abstractmethod
+    @torch.no_grad()
+    def init_codebooks(self, data: torch.Tensor) -> None:
+        """Initialize codebooks for the quantizer.
+
+        Parameters
+        ----------
+        data : torch.Tensor
+            Data to initialize codebooks with
+        """
+
+    @abstractmethod
+    def forward(self, inputs: torch.Tensor) -> QuantizerOutput:
+        """Quantizer forward pass."""
diff --git a/rectools/semantic/tokenizer/emb_dataset.py b/rectools/semantic/tokenizer/emb_dataset.py
new file mode 100644
index 00000000..4ebb3b5b
--- /dev/null
+++ b/rectools/semantic/tokenizer/emb_dataset.py
@@ -0,0 +1,41 @@
+import typing as tp
+
+import numpy as np
+from torch.utils.data import Dataset
+
+
+class EmbDataset(Dataset):
+    """Dataset of item IDs paired with their embeddings.
+
+    Parameters
+    ----------
+    item_ids : list of int
+        Item identifiers.
+    embeddings : np.ndarray
+        Embedding matrix of shape ``(len(item_ids), embed_dim)``.
+
+    Raises
+    ------
+    ValueError
+        If ``item_ids`` and ``embeddings`` differ in length.
+    """
+
+    def __init__(
+        self,
+        item_ids: tp.List[int],
+        embeddings: np.ndarray,
+    ) -> None:
+        if len(item_ids) != len(embeddings):
+            raise ValueError(f"item_ids length ({len(item_ids)}) != embeddings length ({len(embeddings)})")
+        self.item_ids: tp.List[int] = list(item_ids)
+        self.embeddings: np.ndarray = np.asarray(embeddings)
+        self.dim = self.embeddings.shape[-1]
+
+    def __getitem__(self, idx: tp.Union[int, slice]) -> tp.Dict[str, tp.Any]:
+        """Return ``{"item_id": ..., "embed": ...}`` for one index or for a slice.
+
+        An integer index yields a single id and one embedding row; a slice
+        yields a list of ids and the matching block of rows.
+        """
+        # Lists and ndarrays both accept int and slice, so one code path serves both cases.
+        return {
+            "item_id": self.item_ids[idx],
+            "embed": self.embeddings[idx],
+        }
+
+    def __len__(self) -> int:
+        """Return the number of embedding rows."""
+        return len(self.embeddings)
diff --git a/rectools/semantic/tokenizer/loss.py b/rectools/semantic/tokenizer/loss.py
new file mode 100644
index 00000000..637f5fe5
--- /dev/null
+++ b/rectools/semantic/tokenizer/loss.py
@@ -0,0 +1,19 @@
+import torch
+from torch import nn
+from torch.nn.functional import mse_loss
+
+
+class CodebookLoss(nn.Module):
+    """Two-term MSE quantization loss between residuals and codebook embeddings.
+
+    The result is ``mu * mse(stop_grad(y_true), y_pred) + beta * mse(y_true, stop_grad(y_pred))``,
+    so the first term only sends gradients to ``y_pred`` and the second only to ``y_true``.
+
+    Parameters
+    ----------
+    mu : float
+        Weight of the term that updates ``y_pred``.
+    beta : float
+        Weight of the term that updates ``y_true``.
+    reduction : str
+        Reduction passed to ``mse_loss``.
+    """
+
+    def __init__(self, mu: float = 1.0, beta: float = 0.25, reduction: str = "mean"):
+        super().__init__()
+        self.beta = beta
+        self.mu = mu
+        self.reduction = reduction
+
+    def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor:
+        """Return the weighted sum of the two stop-gradient MSE terms."""
+        pred_term = mse_loss(y_true.detach(), y_pred, reduction=self.reduction)
+        true_term = mse_loss(y_true, y_pred.detach(), reduction=self.reduction)
+        return self.mu * pred_term + self.beta * true_term
diff --git a/rectools/semantic/tokenizer/model.py b/rectools/semantic/tokenizer/model.py
new file mode 100644
index 00000000..cbd22fa2
--- /dev/null
+++ b/rectools/semantic/tokenizer/model.py
@@ -0,0 +1,490 @@
+import os
+import typing as tp
+from collections import Counter, defaultdict
+from datetime import datetime
+from os import PathLike
+from os.path import join
+from pathlib import Path
+
+import numpy as np
+import torch
+from tqdm.auto import tqdm
+
+from rectools.semantic import OptimizerType, QuantizerType
+
+from .base_quantizer import Quantizer
+from .emb_dataset import EmbDataset
+from .rkmeans import RKmeans
+from .rqvae import RQVAE
+
+
+class SIDTokenizer:
+ """A class implementing semantic ID tokenizer, which can use either RK-Means or RQ-VAE as its backbone.
+
+ The tokenizer owns the quantizer lifecycle: it constructs the appropriate
+ ``Quantizer`` from configuration parameters, delegates training-related
+ calls (``init_codebooks``, forward pass) to it, and maintains the
+ ``id2sid`` / ``sid2id`` vocabulary mappings.
+
+ Parameters
+ ----------
+ input_dim : int
+ Dimensionality of incoming item embeddings.
+ codebook_sizes : tp.List[int]
+ Number of codes per codebook level.
+ codebook_dim : int
+ Dimensionality of codebook embeddings (used by RQ-VAE;
+ ignored for RK-Means, where codebook dim equals ``input_dim``).
+ hidden_dims : tp.List[int] | None
+ Hidden layer sizes for the RQ-VAE encoder/decoder.
+ Pass ``None`` for RK-Means.
+ quantizer : QuantizerType
+ Which quantizer backend to use: ``"rkmeans"`` or ``"rqvae"``.
+ device : tp.Optional[str], optional
+ Device for quantizer parameters, by default None.
+ adapter_proj_dim : int, optional
+ Number of principal components for the RQ-VAE whitening adapter,
+ by default 512. Ignored for RK-Means.
+ """
+
+    def __init__(
+        self,
+        input_dim: int,
+        codebook_sizes: tp.List[int],
+        codebook_dim: int = 32,
+        hidden_dims: tp.Optional[tp.List[int]] = None,
+        quantizer: QuantizerType = "rkmeans",
+        device: tp.Optional[str] = None,
+        adapter_proj_dim: int = 512,
+    ) -> None:
+        """Store the configuration, create empty vocabularies and build the quantizer in eval mode."""
+        # Configuration read by ``_build_quantizer``.
+        self.input_dim = input_dim
+        self.codebook_sizes = list(codebook_sizes)
+        self.codebook_dim = codebook_dim
+        self.hidden_dims = hidden_dims
+        self.quantizer_name: QuantizerType = quantizer
+        self.device = device
+        self.adapter_proj_dim = adapter_proj_dim
+
+        # SIDs are tuples so they can serve as dict keys. Several items may share one
+        # SID, hence the reverse mapping stores a list of item ids per SID.
+        self.id2sid: tp.Dict[int, tp.Tuple[int, ...]] = {}
+        self.sid2id: tp.Dict[tp.Tuple[int, ...], tp.List[int]] = defaultdict(list)
+
+        quantizer_module = self._build_quantizer()
+        quantizer_module.eval()
+        self._quantizer = quantizer_module
+
+ # ------------------------------------------------------------------
+ # Quantizer construction
+ # ------------------------------------------------------------------
+
+    def _build_quantizer(self) -> Quantizer:
+        """Instantiate the configured backend and move it to ``self.device``.
+
+        ``quantizer_name == "rqvae"`` selects ``RQVAE``; any other value
+        falls back to ``RKmeans``.
+        """
+        quantizer: Quantizer
+        if self.quantizer_name != "rqvae":
+            quantizer = RKmeans(
+                input_dim=self.input_dim,
+                codebook_sizes=self.codebook_sizes,
+            )
+        else:
+            # A missing hidden-layer spec is passed on as an empty list.
+            hidden_dims = [] if self.hidden_dims is None else self.hidden_dims
+            quantizer = RQVAE(
+                input_dim=self.input_dim,
+                codebook_sizes=self.codebook_sizes,
+                hidden_dims=hidden_dims,
+                codebook_dim=self.codebook_dim,
+                adapter_proj_dim=self.adapter_proj_dim,
+            )
+        return quantizer.to(self.device)
+
+ # ------------------------------------------------------------------
+ # Training helpers (delegated to the quantizer)
+ # ------------------------------------------------------------------
+
+    @property
+    def quantizer(self) -> Quantizer:
+        """Access the underlying quantizer module.
+
+        Returns the same object the tokenizer uses internally, not a copy.
+        """
+        return self._quantizer
+
+    def init_codebooks(self, dataset: EmbDataset) -> None:
+        """Seed the quantizer codebooks with every embedding held by ``dataset``.
+
+        Parameters
+        ----------
+        dataset : EmbDataset
+            Dataset whose embeddings are used for codebook initialization.
+        """
+        # A full slice returns all rows at once.
+        all_rows = dataset[:]
+        embs = torch.Tensor(all_rows["embed"]).to(self.device)
+        self._quantizer.init_codebooks(embs)
+
+    def __call__(self, batch: tp.Dict[str, tp.Any]) -> torch.Tensor:
+        """Run a forward pass through the quantizer on a batch and update ``id2sid``.
+
+        Parameters
+        ----------
+        batch : tp.Dict[str, tp.Any]
+            A dict with ``"item_id"`` and ``"embed"`` keys, as returned by
+            ``EmbDataset.__getitem__``. ``"item_id"`` may be a single id
+            (Python or NumPy integer), a sequence of ids, or a tensor of ids.
+
+        Returns
+        -------
+        torch.Tensor
+            Scalar loss from the quantizer.
+        """
+        item_ids = batch["item_id"]
+        # Tensors hash by identity, so turn them into plain Python values before
+        # the ids are used as ``id2sid`` keys.
+        if isinstance(item_ids, torch.Tensor):
+            item_ids = item_ids.tolist()
+        # Integer indexing of the dataset yields one bare id; NumPy integers are not
+        # ``int`` instances, so both kinds have to be wrapped explicitly.
+        if isinstance(item_ids, (int, np.integer)):
+            item_ids = [item_ids]
+
+        embs = torch.Tensor(batch["embed"]).to(self.device)
+        if embs.dim() == 1:
+            # A single embedding vector becomes a batch of one row.
+            embs = embs.unsqueeze(0)
+
+        out = self._quantizer(embs)
+
+        for idx, item in enumerate(item_ids):
+            self.id2sid[item] = out.sem_ids[idx]
+
+        return out.loss
+
+    def train(self) -> "SIDTokenizer":
+        """Set the quantizer to training mode.
+
+        Returns ``self`` so that calls can be chained.
+        """
+        self._quantizer.train()
+        return self
+
+    def eval(self) -> "SIDTokenizer":
+        """Set the quantizer to evaluation mode.
+
+        Returns ``self`` so that calls can be chained.
+        """
+        self._quantizer.eval()
+        return self
+
+    def parameters(self) -> tp.Any:
+        """Return quantizer parameters (for optimizer construction).
+
+        Delegates directly to the quantizer module's ``parameters()``.
+        """
+        return self._quantizer.parameters()
+
+ # ------------------------------------------------------------------
+ # Fit
+ # ------------------------------------------------------------------
+
+    @staticmethod
+    def _get_collision_rate(id2sid: tp.Dict[int, tp.Tuple[int, ...]]) -> float:
+        """Return the share of items whose SID is also assigned to another item.
+
+        An empty mapping yields ``1.0``.
+        """
+        if not id2sid:
+            return 1.0
+        items_per_sid = Counter(id2sid.values())
+        # Every item of a SID used more than once counts as colliding.
+        colliding_items = sum(count for count in items_per_sid.values() if count > 1)
+        return colliding_items / len(id2sid)
+
+    @torch.no_grad()
+    def _get_id2sid(self, dataset: EmbDataset, batch_size: int = 2048) -> tp.Dict[int, tp.Tuple[int, ...]]:
+        """Run the quantizer over ``dataset`` in eval mode and return an id2sid mapping.
+
+        The quantizer's previous train/eval mode is restored afterwards, also
+        when the forward pass raises. ``self.id2sid`` is not modified.
+
+        Parameters
+        ----------
+        dataset : EmbDataset
+            Items to quantize; read in slices of ``batch_size`` rows.
+        batch_size : int
+            Number of rows per forward pass.
+
+        Returns
+        -------
+        dict
+            Item id -> Semantic ID tuple for every item in ``dataset``.
+        """
+        was_training = self._quantizer.training
+        self._quantizer.eval()
+        id2sid: tp.Dict[int, tp.Tuple[int, ...]] = {}
+        try:
+            for start_idx in range(0, len(dataset), batch_size):
+                end_idx = min(len(dataset), start_idx + batch_size)
+                batch = dataset[start_idx:end_idx]
+                batch_embeds = torch.Tensor(batch["embed"]).to(self.device)
+                out = self._quantizer(batch_embeds)
+                for idx, item in enumerate(batch["item_id"]):
+                    id2sid[item] = out.sem_ids[idx]
+        finally:
+            # Without this, an exception would leave a training-mode quantizer stuck in eval mode.
+            if was_training:
+                self._quantizer.train()
+        return id2sid
+
+ def fit( # pylint: disable=too-many-locals
+ self,
+ item_ids: tp.List[int],
+ embeddings: np.ndarray,
+ init_max_items: int = 10_000,
+ max_epochs: int = 1_000,
+ patience: int = 100,
+ optimizer: OptimizerType = "adamw",
+ lr: float = 1e-3,
+ weight_decay: float = 1e-4,
+ batch_size: int = 2048,
+ save_dir: str = "checkpoints",
+ ) -> "SIDTokenizer":
+ """Train the quantizer on pre-computed item embeddings.
+
+ After training, ``id2sid`` is populated for every item in ``item_ids``.
+
+ Parameters
+ ----------
+ item_ids : tp.List[int]
+ Item identifiers matching rows of ``embeddings``.
+ embeddings : np.ndarray
+ Embedding matrix of shape ``(len(item_ids), embed_dim)``.
+ init_max_items : int, optional
+ Maximum number of items used for codebook initialization,
+ by default 10_000.
+ max_epochs : int, optional
+ Maximum training epochs, by default 1_000.
+ patience : int, optional
+ Early stopping patience (epochs without collision rate
+ improvement), by default 100.
+ optimizer : OptimizerType, optional
+ Optimizer name, by default ``"adamw"``.
+ lr : float, optional
+ Learning rate, by default 1e-3.
+ weight_decay : float, optional
+ Weight decay, by default 1e-4.
+ batch_size : int, optional
+ Training batch size, by default 2048.
+ save_dir : str, optional
+ Directory for saving best checkpoint, by default ``"checkpoints"``.
+
+ Returns
+ -------
+ SIDTokenizer
+ ``self``, with trained quantizer and populated ``id2sid``.
+ """
+ dataset = EmbDataset(item_ids, embeddings)
+
+ # Initialize codebooks from a (possibly smaller) subset
+ init_dataset = EmbDataset(item_ids[:init_max_items], embeddings[:init_max_items])
+ self.init_codebooks(init_dataset)
+
+ opt: torch.optim.Optimizer
+ if optimizer.lower() == "adam":
+ opt = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
+ elif optimizer.lower() == "adagrad":
+ opt = torch.optim.Adagrad(self.parameters(), lr=lr, weight_decay=weight_decay)
+ else:
+ opt = torch.optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay)
+
+ best_collision_rate = float("inf")
+ steps_no_improve = 0
+ best_path = Path(
+ join(
+ save_dir,
+ f"tokenizer_{self.quantizer_name}",
+ f"{datetime.now().strftime('%d%b%Y-%H.%M.%S.%f')}",
+ "best_perf.pt",
+ )
+ )
+ os.makedirs(best_path.parent, exist_ok=True)
+
+ self.train()
+ with tqdm(
+ total=max_epochs * (len(dataset) // batch_size + (len(dataset) % batch_size > 0)),
+ ) as pbar:
+ for epoch_no in range(max_epochs):
+ pbar.set_description(f"Tokenizer train {epoch_no} / {max_epochs}")
+ for start_idx in range(0, len(dataset), batch_size):
+ end_idx = min(start_idx + batch_size, len(dataset))
+ batch = dataset[start_idx:end_idx]
+ batch_embeds = torch.Tensor(batch["embed"]).to(self.device)
+ out = self._quantizer(batch_embeds)
+ loss = out.loss
+ opt.zero_grad()
+ loss.backward()
+ opt.step()
+
+ metrics_dict = {"loss": loss.item()}
+ pbar.set_postfix({metric: f"{value:.4f}" for metric, value in metrics_dict.items()})
+ pbar.update(1)
+
+ # Compute collision rate
+ id2sid = self._get_id2sid(dataset, batch_size=batch_size)
+ collision_rate = self._get_collision_rate(id2sid)
+
+ improved = collision_rate < best_collision_rate
+ if improved:
+ best_collision_rate = collision_rate
+ steps_no_improve = 0
+ torch.save(self._quantizer.state_dict(), best_path)
+ else:
+ steps_no_improve += 1
+
+ self.train()
+
+ if steps_no_improve >= patience:
+ print(f"Early stopping at epoch {epoch_no} (patience exceeded {patience})")
+ break
+
+ print(f"Best model's collision rate: {best_collision_rate * 100:.2f}%")
+ self._quantizer.load_state_dict(torch.load(best_path, map_location=self.device))
+ self.eval()
+ self.id2sid = self._get_id2sid(dataset, batch_size=batch_size)
+ return self
+
+ # ------------------------------------------------------------------
+ # Vocabulary extension
+ # ------------------------------------------------------------------
+
+ @torch.no_grad()
+ def extend(self, dataset: EmbDataset) -> None:
+ """Tokenize new items from an EmbDataset and add them to the vocabulary.
+
+ Items already present in id2sid are skipped.
+ """
+ for start_idx in range(0, len(dataset), 256):
+ end_idx = min(start_idx + 256, len(dataset))
+ batch = dataset[start_idx:end_idx]
+
+ item_ids = batch["item_id"]
+ if isinstance(item_ids, int):
+ item_ids = [item_ids]
+
+ new_mask = [i for i, item in enumerate(item_ids) if item not in self.id2sid]
+ if not new_mask:
+ continue
+
+ new_embs = torch.Tensor(batch["embed"]).to(self.device)
+ if new_embs.dim() == 1:
+ new_embs = new_embs.unsqueeze(0)
+ new_embs = new_embs[new_mask]
+
+ out = self._quantizer(new_embs)
+ sids = out.sem_ids
+ for j, mask_idx in enumerate(new_mask):
+ item = item_ids[mask_idx]
+ sid = sids[j]
+ self.id2sid[item] = sid
+ if self.sid2id:
+ self.sid2id[sid].append(item)
+
+ # ------------------------------------------------------------------
+ # Tokenize / decode
+ # ------------------------------------------------------------------
+
+ def _get_sid(self, item: int) -> tp.Tuple[int, ...]:
+ if item in self.id2sid:
+ return self.id2sid[item]
+ raise ValueError(
+ f"Item with ID {item} is out of vocabulary. " "Use extend() with an EmbDataset containing this item first."
+ )
+
+ def _get_id(
+ self,
+ sid: tp.Tuple[int, ...],
+ default_value: tp.Optional[int] = None,
+ rng: tp.Optional[np.random.RandomState] = None,
+ ) -> tp.Optional[int]:
+ if len(self.sid2id) == 0:
+ for item_id, item_sid in self.id2sid.items():
+ self.sid2id[item_sid].append(item_id)
+ if sid in self.sid2id:
+ if rng is None:
+ rng = np.random
+ return int(rng.choice(self.sid2id[sid]))
+ return default_value
+
+ def tokenize( # pylint: disable=redefined-builtin
+ self, input: int | tp.Iterable[int]
+ ) -> tp.Tuple[int, ...] | tp.List[tp.Tuple[int, ...]]:
+ """Convert incoming item IDs to corresponding semantic IDs.
+
+ Parameters
+ ----------
+ input : int | tp.Iterable[int]
+ Item ID(s).
+
+ Returns
+ -------
+ tp.Tuple[int, ...] | tp.List[tp.Tuple[int, ...]]
+ Item SID(s).
+ """
+ if isinstance(input, int):
+ return self._get_sid(input)
+
+ output = []
+ for item in input:
+ output.append(self._get_sid(item))
+ return output
+
+ def decode( # pylint: disable=redefined-builtin
+ self,
+ input: tp.Tuple[int, ...] | tp.Iterable[tp.Tuple[int, ...]],
+ default_value: tp.Optional[int] = None,
+ rng: tp.Optional[np.random.RandomState] = None,
+ ) -> tp.Optional[int] | tp.List[tp.Optional[int]]:
+ """Decode incoming SID(s) to corresponding item IDs.
+
+ Parameters
+ ----------
+ input : tp.Tuple[int, ...] | tp.Iterable[tp.Tuple[int, ...]]
+ Either a single SID or an iterable of SIDs.
+ default_value : tp.Optional[int], optional
+ What to return when there is no item corresponding to given SID, by default None
+ rng : np.random.RandomState, optional
+ Random generator used to resolve SID collisions. If omitted, the
+ global NumPy random generator is used.
+
+ Returns
+ -------
+ tp.Optional[int] | tp.List[tp.Optional[int]]
+ Item ID(s) corresponding to input SID(s).
+ """
+ # check if the input is only a single SID
+ if isinstance(input, tuple) and all(map(lambda x: isinstance(x, int), input)):
+ return self._get_id(input, default_value=default_value, rng=rng)
+ # type: ignore[arg-type]
+ return list(map(lambda item: self._get_id(item, default_value=default_value, rng=rng), input))
+
+ # ------------------------------------------------------------------
+ # Save / load
+ # ------------------------------------------------------------------
+
+ def save(self, path: tp.Union[str, PathLike[str]]) -> None:
+ """Save tokenizer to the specified path.
+
+ The checkpoint is a dict containing quantizer configuration,
+ quantizer ``state_dict``, and the ``id2sid`` vocabulary mapping.
+
+ Parameters
+ ----------
+ path : str or PathLike
+ File path to write the checkpoint to.
+ """
+ checkpoint = {
+ "quantizer_name": self.quantizer_name,
+ "input_dim": self.input_dim,
+ "codebook_sizes": self.codebook_sizes,
+ "codebook_dim": self.codebook_dim,
+ "hidden_dims": self.hidden_dims,
+ "adapter_proj_dim": self.adapter_proj_dim,
+ "quantizer_state_dict": self._quantizer.state_dict(),
+ "id2sid": self.id2sid,
+ }
+ torch.save(checkpoint, path)
+
+ @classmethod
+ def load(
+ cls,
+ path: tp.Union[str, PathLike[str]],
+ device: tp.Optional[str] = None,
+ map_location: tp.Optional[str] = None,
+ ) -> "SIDTokenizer":
+ """Load serialized SIDTokenizer.
+
+ Parameters
+ ----------
+ path : str or PathLike
+ Location of the serialized tokenizer.
+ device : tp.Optional[str], optional
+ Device to load tokenizer's quantizer to, by default None.
+ map_location : tp.Optional[str], optional
+ ``map_location`` forwarded to ``torch.load``. If provided,
+ also used as the device. By default None.
+
+ Returns
+ -------
+ SIDTokenizer
+ A de-serialized tokenizer.
+ """
+ effective_device = map_location or device
+ checkpoint = torch.load(path, map_location=effective_device, weights_only=False)
+
+ tokenizer = cls(
+ input_dim=checkpoint["input_dim"],
+ codebook_sizes=checkpoint["codebook_sizes"],
+ codebook_dim=checkpoint["codebook_dim"],
+ hidden_dims=checkpoint["hidden_dims"],
+ quantizer=checkpoint["quantizer_name"],
+ device=effective_device,
+ adapter_proj_dim=checkpoint["adapter_proj_dim"],
+ )
+ tokenizer._quantizer.load_state_dict(checkpoint["quantizer_state_dict"])
+ tokenizer._quantizer.eval()
+ tokenizer.id2sid = checkpoint["id2sid"]
+
+ return tokenizer
+
+ def __len__(self) -> int:
+ return len(set(self.id2sid.values()))
diff --git a/rectools/semantic/tokenizer/rkmeans.py b/rectools/semantic/tokenizer/rkmeans.py
new file mode 100644
index 00000000..18cc3019
--- /dev/null
+++ b/rectools/semantic/tokenizer/rkmeans.py
@@ -0,0 +1,171 @@
+import typing as tp
+
+import numpy as np
+import torch
+from k_means_constrained import KMeansConstrained
+from torch import nn
+from torch.nn.functional import mse_loss, normalize
+from tqdm.auto import tqdm
+
+from .base_quantizer import Quantizer, QuantizerOutput
+
+
class Codebook(nn.Module):
    """A single RK-Means codebook level.

    Receives item embeddings (or residuals left by previous codebooks) and
    assigns each row to the nearest codebook embedding.

    Parameters
    ----------
    code_dim : int
        Size of codebook embeddings.
    n_codes : int
        Number of codebook embeddings.
    """

    def __init__(
        self,
        code_dim: int,
        n_codes: int,
    ) -> None:
        super().__init__()
        self.code_dim = code_dim
        self.n_codes = n_codes
        self.code_embs = nn.Embedding(n_codes, code_dim)
        # Flipped to True once centroids have been copied in.
        self.kmeans_initted_ = False

    def init_from_centroids(self, centroids: torch.Tensor) -> None:
        """Overwrite the codebook embeddings with precomputed centroids.

        Parameters
        ----------
        centroids : torch.Tensor
            Precomputed centroids of shape ``[n_codes, code_dim]``.
        """
        expected_shape = (self.n_codes, self.code_dim)
        assert centroids.shape == expected_shape
        with torch.no_grad():
            self.code_embs.weight.data.copy_(centroids)
        self.kmeans_initted_ = True

    def forward(self, inputs: torch.Tensor) -> tp.Tuple:
        """Quantize ``inputs`` against the codebook.

        Parameters
        ----------
        inputs : torch.Tensor
            Input residuals.

        Returns
        -------
        tp.Tuple
            ``(codes, quantized, residuals)``: index of the nearest codebook
            embedding per row, the quantized representations (with a
            straight-through gradient), and ``inputs`` minus the selected
            codebook embeddings.
        """
        table = self.code_embs.weight
        # Broadcast rows against codebook entries and measure L2 distances.
        pairwise = inputs[:, None] - table[None]
        distances = torch.linalg.vector_norm(pairwise, dim=-1)  # pylint: disable=not-callable
        codes = torch.argmin(distances, dim=-1)

        nearest = table[codes, :]
        residuals = inputs - nearest

        # Straight-through estimator: the forward value equals the selected
        # codebook embedding, while gradients flow as if it were ``inputs``.
        quantized = inputs + (nearest - inputs).detach()

        return codes, quantized, residuals
+
+
class RKmeans(Quantizer):
    """Apply Residual Mini-Batch K-Means (RK-Means) -- a multi-level vector quantizer
    that applies quantization on residuals to generate a tuple of codewords (aka Semantic IDs).

    Parameters
    ----------
    input_dim : int
        Input embedding dimension.
    codebook_sizes : tp.List[int]
        Number of embeddings contained in RK-Means codebooks.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_sizes: tp.List[int],
    ) -> None:
        super().__init__(
            input_dim=input_dim,
            codebook_sizes=codebook_sizes,
        )
        for n_codes in codebook_sizes:
            self.codebooks.append(Codebook(self.input_dim, n_codes))

    @torch.no_grad()
    def init_codebooks(self, data: torch.Tensor) -> None:
        """Initialize codebook centroids using constrained k-means.

        ``data`` is L2-normalized; then, for each codebook level,
        KMeansConstrained (with a minimum and maximum cluster size) is fitted
        on the current residuals, the resulting centroids are copied into the
        codebook embeddings, and the re-normalized residuals are passed on to
        the next level.

        Parameters
        ----------
        data : torch.Tensor
            Input embeddings.
        """
        residuals = normalize(data).cpu().numpy()

        for layer in tqdm(self.codebooks, desc="Initializing codebooks with constrained k-means"):
            # At most 50 points are required per cluster, fewer when data is scarce.
            size_min = min(len(data) // (layer.n_codes * 2), 50)
            # Cap clusters at 4x the minimum only if that still leaves room for every point.
            size_max = size_min * 4 if layer.n_codes * size_min * 4 > len(data) else len(data)
            km = KMeansConstrained(
                n_clusters=layer.n_codes,
                size_min=size_min,
                size_max=size_max,
                random_state=0,
                max_iter=10,
                n_init=10,
                n_jobs=10,
                verbose=False,
            )
            km.fit(residuals)
            centroids = torch.from_numpy(km.cluster_centers_).to(data.device)
            layer.init_from_centroids(centroids)

            # Compute residuals for the next level
            assignments = km.labels_
            residuals = residuals - km.cluster_centers_[assignments]
            residuals = residuals / (np.linalg.norm(residuals, axis=1, keepdims=True) + 1e-12)

    def forward(self, inputs: torch.Tensor) -> QuantizerOutput:
        """Run a forward pass.

        Parameters
        ----------
        inputs : torch.Tensor
            Input item embeddings.

        Returns
        -------
        QuantizerOutput
            Semantic IDs corresponding to input embeddings (one tuple of ints
            per row) and the reconstruction loss.
        """
        res_prev = normalize(inputs)
        embeds, codes_per_level = [], []
        for layer in self.codebooks:
            codes, quantized, res = layer(res_prev)
            # Each level quantizes the re-normalized residual of the previous one.
            res_prev = normalize(res)
            embeds.append(quantized)
            codes_per_level.append(codes)

        # Sum the per-level quantized vectors: [B, num_codebooks, D] -> [B, D]
        embeds_tensor = torch.stack(embeds, dim=1).sum(dim=1)
        sem_ids_tensor = torch.stack(codes_per_level, dim=1).to(torch.int32)

        # Convert semids to tuples of ints with a single device-to-host transfer
        sem_ids = [tuple(row) for row in sem_ids_tensor.tolist()]

        # Reconstruction loss
        # NOTE(review): the straight-through outputs are detached from the codebook
        # weights, so the last level appears to receive no gradient from this loss -- confirm.
        loss = mse_loss(inputs, embeds_tensor, reduction="mean")
        return QuantizerOutput(sem_ids=sem_ids, loss=loss)
diff --git a/rectools/semantic/tokenizer/rqvae.py b/rectools/semantic/tokenizer/rqvae.py
new file mode 100644
index 00000000..7ae37d57
--- /dev/null
+++ b/rectools/semantic/tokenizer/rqvae.py
@@ -0,0 +1,304 @@
+import typing as tp
+
+import torch
+from k_means_constrained import KMeansConstrained
+from torch import nn
+from torch.nn.functional import mse_loss
+from tqdm.auto import tqdm
+
+from rectools.semantic.modules import MLP
+
+from .base_quantizer import Quantizer, QuantizerOutput
+from .loss import CodebookLoss
+
+
+def _get_pca_projection(embeddings: torch.Tensor, out_dim: int) -> torch.Tensor:
+ centered = embeddings - embeddings.mean(dim=0)
+ _, _, vh = torch.linalg.svd(centered, full_matrices=True) # pylint: disable=not-callable
+ return vh[:out_dim].T.contiguous()
+
+
class WhiteningAdapter(nn.Module):
    r"""Applies an adapter consisting of PCA projection and an optional feed-forward transformation to
    incoming item embeddings.

    The projection modifies input embeddings `x` in the following manner:

    .. math::
        \text{out}(x) = (x - b) \times W

    Where :math:`W` is a weight matrix of shape ``[emb_dim, proj_dim]`` initialized with principal
    components and :math:`b` is a bias vector initialized as the mean embedding.

    Parameters
    ----------
    adapter_type : tp.Literal["ffn", "identity"], optional
        Transformation applied after PCA projection, by default "identity"
    emb_dim : int, optional
        Dimension of incoming embeddings, by default 768
    proj_dim : int, optional
        How many principal components to use, by default 512
    hidden_units : int, optional
        Size of hidden layer in the "ffn" adapter, by default 256
    dropout : float, optional
        Dropout probability for the "ffn" adapter, by default 0.1
    device : str, optional
        Device the adapter is on, by default "cpu". Stored on the instance only; the forward pass
        takes the device type from its inputs.
    """

    def __init__(
        self,
        adapter_type: tp.Literal["ffn", "identity"] = "identity",
        emb_dim: int = 768,
        proj_dim: int = 512,
        hidden_units: int = 256,
        dropout: float = 0.1,
        device: str = "cpu",
    ) -> None:
        super().__init__()

        self.adapter_type = adapter_type
        self.hidden_units = hidden_units
        self.emb_dim = emb_dim
        self.proj_dim = proj_dim
        self.dropout = dropout
        self.device = device
        # Both tensors stay uninitialized until ``init_from_embeds`` or ``load_state_dict`` fills them.
        self.weight = nn.Parameter(torch.empty(self.emb_dim, self.proj_dim))
        self.bias = nn.Parameter(torch.empty(self.emb_dim))

        self.head: nn.Module
        if self.adapter_type == "ffn":
            self.head = MLP(self.proj_dim, [self.hidden_units], self.proj_dim, dropout=self.dropout)
        else:
            self.head = nn.Identity()

    def freeze_adapter(self) -> None:
        """Freeze PCA transform's weights and biases."""
        self.bias.requires_grad = False
        self.weight.requires_grad = False

    @torch.no_grad()
    def init_from_embeds(self, embeddings: torch.Tensor) -> None:
        """Initialize PCA projection weights using item metadata embeddings and freeze them.

        Parameters
        ----------
        embeddings : torch.Tensor
            Item metadata embeddings
        """
        proj = _get_pca_projection(embeddings, self.proj_dim)

        self.bias.data.copy_(embeddings.mean(dim=0))
        self.weight.data.copy_(proj)
        self.freeze_adapter()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Center and project ``inputs`` in fp32, then apply the adapter head."""
        # The PCA projection is precision-sensitive (a long matmul accumulates bf16
        # rounding errors), so fp32 is forced even inside autocast. Autocast has to be
        # disabled for the device type the computation actually runs on, hence it is
        # taken from ``inputs`` rather than from the constructor argument.
        with torch.amp.autocast(inputs.device.type, enabled=False):
            projected = (inputs.float() - self.bias.float()) @ self.weight.float()
        return self.head(projected)
+
+
class Codebook(nn.Module):
    """A single RQ-VAE codebook level.

    Receives RQ-VAE encoder outputs (or residuals left by previous codebooks)
    and assigns each row to the nearest codebook embedding.

    Parameters
    ----------
    code_dim : int
        Size of codebook embeddings.
    n_codes : int
        Number of codebook embeddings.
    """

    def __init__(
        self,
        code_dim: int,
        n_codes: int,
    ) -> None:
        super().__init__()
        self.code_dim = code_dim
        self.n_codes = n_codes
        self.code_embs = nn.Embedding(n_codes, code_dim)
        self.cb_loss = CodebookLoss()
        # Flipped to True once centroids have been copied in.
        self.kmeans_initted_ = False

    def init_from_centroids(self, centroids: torch.Tensor) -> None:
        """Overwrite the codebook embeddings with precomputed centroids.

        Parameters
        ----------
        centroids : torch.Tensor
            Precomputed centroids of shape ``[n_codes, code_dim]``.
        """
        expected_shape = (self.n_codes, self.code_dim)
        assert centroids.shape == expected_shape
        with torch.no_grad():
            self.code_embs.weight.data.copy_(centroids)
        self.kmeans_initted_ = True

    def forward(self, inputs: torch.Tensor) -> tp.Tuple:
        """Quantize ``inputs`` against the codebook.

        Parameters
        ----------
        inputs : torch.Tensor
            Input residuals expected to be of shape [B, emb_dim]

        Returns
        -------
        tp.Tuple
            ``(codes, quantized_st, residuals, loss)``: index of the nearest
            codebook embedding per row, the quantized representations with a
            straight-through gradient, ``inputs`` minus the selected codebook
            embeddings, and the codebook loss.
        """
        table = self.code_embs.weight
        # [B, 1, emb_dim] - [1, cb_size, emb_dim] -> [B, cb_size, emb_dim]
        pairwise = inputs[:, None] - table[None]
        distances = torch.linalg.vector_norm(pairwise, dim=-1)  # pylint: disable=not-callable
        codes = torch.argmin(distances, dim=-1)  # out: [B]

        selected = self.code_embs(codes)
        residuals = inputs - selected

        # Straight-through estimator: the forward value equals the selected
        # codebook embedding, while gradients flow as if it were ``inputs``.
        quantized_st = inputs + (selected - inputs).detach()

        # The codebook loss is computed on the non-detached embeddings.
        loss = self.cb_loss(inputs, selected)

        return codes, quantized_st, residuals, loss
+
+
class RQVAE(Quantizer):
    """Apply Residual-Quantized Variational AutoEncoder (RQ-VAE) -- a multi-level vector quantizer
    that applies quantization on residuals to generate a tuple of codewords (aka Semantic IDs).
    The Autoencoder is jointly trained by updating the quantization codebook and the DNN
    encoder-decoder parameters.

    Parameters
    ----------
    input_dim : int
        Input embeddings size.
    codebook_dim : int
        Size of codebook embeddings.
    hidden_dims : tp.List[int]
        Encoder and decoder hidden layer sizes.
    codebook_sizes : tp.List[int]
        Number of embeddings each codebook contains.
    adapter_type : tp.Literal["ffn", "identity"], optional
        Transformation applied after PCA projection, by default "identity"
    adapter_proj_dim : int, optional
        How many principal components to use, by default 512
    """

    def __init__(
        self,
        input_dim: int,
        codebook_dim: int,
        hidden_dims: tp.List[int],
        codebook_sizes: tp.List[int],
        adapter_type: tp.Literal["ffn", "identity"] = "identity",
        adapter_proj_dim: int = 512,
    ) -> None:
        super().__init__(
            input_dim=input_dim,
            codebook_sizes=codebook_sizes,
        )
        self.codebook_dim = codebook_dim
        for n_codes in codebook_sizes:
            self.codebooks.append(Codebook(codebook_dim, n_codes))

        self.adapter_type = adapter_type
        self.adapter_proj_dim = adapter_proj_dim
        self.adapter = WhiteningAdapter(
            adapter_type=adapter_type,
            emb_dim=input_dim,
            proj_dim=adapter_proj_dim,
        )

        # The decoder mirrors the encoder: reversed hidden sizes, back to the adapter output size.
        self.encoder = MLP(adapter_proj_dim, hidden_dims, codebook_dim)
        self.decoder = MLP(codebook_dim, hidden_dims[::-1], adapter_proj_dim)

    @torch.no_grad()
    def init_codebooks(self, data: torch.Tensor) -> None:
        """Initialize the adapter and the codebook centroids.

        The adapter is initialized from `data`, then `data` is passed through
        the adapter and the encoder. For each codebook level
        KMeansConstrained (with a minimum and maximum cluster size) is fitted
        on the current residuals and the resulting centroids are copied into
        the codebook embeddings.

        Parameters
        ----------
        data : torch.Tensor
            Item metadata embeddings.
        """
        self.adapter.init_from_embeds(data)
        embs = self.adapter(data)
        residuals = self.encoder(embs).cpu().numpy()

        for layer in tqdm(self.codebooks, desc="Initializing codebooks with constrained k-means"):
            # At most 50 points are required per cluster, fewer when data is scarce.
            size_min = min(len(data) // (layer.n_codes * 2), 50)
            # Cap clusters at 4x the minimum only if that still leaves room for every point.
            size_max = size_min * 4 if layer.n_codes * size_min * 4 > len(data) else len(data)
            km = KMeansConstrained(
                n_clusters=layer.n_codes,
                size_min=size_min,
                size_max=size_max,
                random_state=0,
                max_iter=10,
                n_init=10,
                n_jobs=10,
                verbose=False,
            )
            km.fit(residuals)
            centroids = torch.from_numpy(km.cluster_centers_).to(data.device)
            layer.init_from_centroids(centroids)

            # Compute residuals for the next level
            assignments = km.labels_
            residuals = residuals - km.cluster_centers_[assignments]

    def forward(self, inputs: torch.Tensor) -> QuantizerOutput:
        """RQ-VAE forward pass.

        Parameters
        ----------
        inputs : torch.Tensor
            Input embeddings

        Returns
        -------
        QuantizerOutput
            Semantic IDs corresponding to input embeddings (one tuple of ints
            per row) and the RQ-VAE loss (codebook losses + reconstruction loss).
        """
        inputs_adapted = self.adapter(inputs)
        inputs_enc = self.encoder(inputs_adapted)

        loss = torch.tensor(0.0, device=next(self.codebooks.parameters()).device)

        res_prev = inputs_enc
        embeds, codes_per_level = [], []
        for layer in self.codebooks:
            codes, quantized_st, res, loss_upd = layer(res_prev)
            res_prev = res
            loss += loss_upd
            embeds.append(quantized_st)
            codes_per_level.append(codes)

        # Sum the per-level quantized vectors: [B, num_codebooks, D] -> [B, D]
        embeds_tensor = torch.stack(embeds, dim=1).sum(dim=1)
        sem_ids_tensor = torch.stack(codes_per_level, dim=1).to(torch.int32)

        # Convert semids to tuples of ints with a single device-to-host transfer
        sem_ids = [tuple(row) for row in sem_ids_tensor.tolist()]

        embeds_dec = self.decoder(embeds_tensor)

        # Reconstruction loss is measured in the adapter output space
        loss += mse_loss(inputs_adapted, embeds_dec, reduction="mean")

        return QuantizerOutput(sem_ids=sem_ids, loss=loss)
diff --git a/setup.cfg b/setup.cfg
index cbd220b9..d1b23834 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -49,6 +49,9 @@ per-file-ignores =
rectools/models/nn/dssm.py: D101,D102,N812
rectools/dataset/torch_datasets.py: D101,D102
rectools/models/implicit_als.py: N806
+ rectools/semantic/modules/transformer_blocks.py: N806
+ rectools/semantic/tiger/module.py: D101,D102,N806
+ rectools/semantic/tiger/lightning.py: D101,D102
rectools/fast_transformers/net.py: N806
rectools/fast_transformers/unisrec/net.py: N806
diff --git a/tests/semantic/__init__.py b/tests/semantic/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/semantic/data_handling/__init__.py b/tests/semantic/data_handling/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/semantic/data_handling/test_dataset.py b/tests/semantic/data_handling/test_dataset.py
new file mode 100644
index 00000000..bafa7f75
--- /dev/null
+++ b/tests/semantic/data_handling/test_dataset.py
@@ -0,0 +1,170 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pandas as pd
+import pytest
+from pytorch_lightning import seed_everything
+
+from rectools.semantic.data_handling.dataset import PaddingCollateFn, TIGERDataset
+from rectools.semantic.tokenizer.emb_dataset import EmbDataset
+from rectools.semantic.tokenizer.model import SIDTokenizer
+
+
@pytest.fixture
def trained_tokenizer() -> SIDTokenizer:
    """Build an RK-Means tokenizer over 20 seeded random 16-dim item embeddings.

    Codebooks are initialized from the dataset and the tokenizer is then
    called once on the full batch.
    """
    seed_everything(42, workers=True)
    n_items, emb_dim = 20, 16
    random_state = np.random.RandomState(42)
    emb_dataset = EmbDataset(
        list(range(1, n_items + 1)),
        random_state.randn(n_items, emb_dim).astype(np.float32),
    )
    tokenizer_kwargs = {
        "codebook_dim": 16,
        "hidden_dims": None,
        "codebook_sizes": [4, 4],
        "input_dim": 16,
        "device": "cpu",
        "quantizer": "rkmeans",
    }
    tokenizer = SIDTokenizer(**tokenizer_kwargs)
    tokenizer.init_codebooks(emb_dataset)
    full_batch = emb_dataset[:]
    tokenizer(full_batch)
    return tokenizer
+
+
@pytest.fixture
def interactions() -> pd.DataFrame:
    """Three users with four interactions each; items and timestamps run from 1 to 12."""
    n_users, per_user = 3, 4
    total = n_users * per_user
    user_ids = [user for user in range(1, n_users + 1) for _ in range(per_user)]
    running = list(range(1, total + 1))
    return pd.DataFrame(
        {
            "user_id": user_ids,
            "item_id": running,
            "timestamp": list(running),
        }
    )
+
+
class TestTIGERDataset:  # pylint: disable=redefined-outer-name
    """Checks of ``TIGERDataset`` length and sample layout in eval and train modes."""

    @staticmethod
    def _make_dataset(tokenizer: SIDTokenizer, frame: pd.DataFrame, **overrides: object) -> TIGERDataset:
        """Build a ``TIGERDataset`` with the settings shared by all tests, updated with ``overrides``."""
        params = {
            "interactions": frame,
            "tokenizer": tokenizer,
            "codebook_sizes": [4, 4],
            "max_length": 10,
            "codeword_offset": 2,
            "bos_token_id": 1,
            "only_last": True,
            "eval_mode": True,
        }
        params.update(overrides)
        return TIGERDataset(**params)

    def test_eval_mode_len(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None:
        dataset = self._make_dataset(trained_tokenizer, interactions)
        # 3 users, each gets one sample
        assert len(dataset) == 3

    def test_eval_mode_getitem(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None:
        dataset = self._make_dataset(trained_tokenizer, interactions)
        first = dataset[0]
        assert "input_ids" in first
        assert "labels" in first
        assert isinstance(first["input_ids"], np.ndarray)
        assert isinstance(first["labels"], (int, np.integer))

    def test_train_mode_only_last(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None:
        dataset = self._make_dataset(trained_tokenizer, interactions, eval_mode=False)
        assert len(dataset) == 3
        first = dataset[0]
        assert "input_ids" in first
        assert "dec_input" in first
        assert "labels" in first

    def test_train_mode_all_subsequences(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None:
        dataset = self._make_dataset(trained_tokenizer, interactions, only_last=False, eval_mode=False)
        # Each user has 4 items -> 3 subsequences per user -> 9 total
        assert len(dataset) == 9

    def test_max_users(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None:
        dataset = self._make_dataset(trained_tokenizer, interactions, max_users=2)
        assert len(dataset) == 2

    def test_dec_input_starts_with_bos(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None:
        dataset = self._make_dataset(trained_tokenizer, interactions, eval_mode=False)
        first = dataset[0]
        assert first["dec_input"][0] == 1  # BOS token
+
+
class TestPaddingCollateFn:
    """Checks of ``PaddingCollateFn`` on array-valued and scalar labels."""

    def test_pads_sequences(self) -> None:
        long_sample = {"input_ids": np.array([1, 2, 3]), "labels": np.array([10, 11])}
        short_sample = {"input_ids": np.array([4, 5]), "labels": np.array([12, 13])}
        collate_fn = PaddingCollateFn(padding_value=0, labels_padding_value=-100)

        collated = collate_fn([long_sample, short_sample])

        assert collated["input_ids"].shape == (2, 3)
        assert collated["labels"].shape == (2, 2)
        # Second sequence should be padded
        padded_cell = collated["input_ids"][1, 2]
        assert padded_cell.item() == 0

    def test_scalar_labels(self) -> None:
        samples = [
            {"input_ids": np.array(ids), "labels": label}
            for ids, label in (([1, 2, 3], 10), ([4, 5, 6], 20))
        ]
        collate_fn = PaddingCollateFn(padding_value=0)

        labels = collate_fn(samples)["labels"]

        assert labels.shape == (2,)
        assert labels.tolist() == [10, 20]
diff --git a/tests/semantic/data_handling/test_k_core.py b/tests/semantic/data_handling/test_k_core.py
new file mode 100644
index 00000000..956b6ec8
--- /dev/null
+++ b/tests/semantic/data_handling/test_k_core.py
@@ -0,0 +1,87 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pandas as pd
+import pytest
+
+from rectools.semantic.data_handling.k_core import k_core
+
+
class TestKCore:
    """Checks of ``k_core`` filtering by minimum user and item interaction counts."""

    @pytest.fixture
    def interactions(self) -> pd.DataFrame:
        """Five users who each interacted with the same three items."""
        items = [10, 20, 30]
        users = [1, 2, 3, 4, 5]
        return pd.DataFrame(
            {
                "user_id": [user for user in users for _ in items],
                "item_id": items * len(users),
            }
        )

    def test_no_filtering_needed(self, interactions: pd.DataFrame) -> None:
        filtered = k_core(interactions, user_min_interactions=3, item_min_interactions=3)
        assert len(filtered) == len(interactions)

    def test_filters_users_below_threshold(self) -> None:
        frame = pd.DataFrame({"user_id": [1, 1, 1, 2], "item_id": [10, 20, 30, 10]})
        filtered = k_core(frame, user_min_interactions=2, item_min_interactions=1)
        remaining_users = set(filtered["user_id"].unique())
        assert remaining_users == {1}

    def test_filters_items_below_threshold(self) -> None:
        frame = pd.DataFrame({"user_id": [1, 1, 2, 2, 3, 3], "item_id": [10, 20, 10, 20, 10, 30]})
        filtered = k_core(frame, user_min_interactions=1, item_min_interactions=2)
        assert 30 not in filtered["item_id"].values

    def test_iterative_filtering(self) -> None:
        # User 3 has only item 30, but item 30 only has user 3
        # Removing item 30 removes user 3 interactions, then user 3 has 0
        frame = pd.DataFrame({"user_id": [1, 1, 2, 2, 3], "item_id": [10, 20, 10, 20, 30]})
        filtered = k_core(frame, user_min_interactions=2, item_min_interactions=2)
        assert 3 not in filtered["user_id"].values
        assert 30 not in filtered["item_id"].values
        assert len(filtered) == 4

    def test_custom_column_names(self) -> None:
        frame = pd.DataFrame({"uid": [1, 1, 2, 2], "iid": [10, 20, 10, 20]})
        filtered = k_core(
            frame,
            user_min_interactions=2,
            item_min_interactions=2,
            user_col="uid",
            item_col="iid",
        )
        assert len(filtered) == 4

    def test_empty_result(self) -> None:
        frame = pd.DataFrame({"user_id": [1, 2, 3], "item_id": [10, 20, 30]})
        filtered = k_core(frame, user_min_interactions=2, item_min_interactions=2)
        assert len(filtered) == 0
diff --git a/tests/semantic/data_handling/test_loo_split.py b/tests/semantic/data_handling/test_loo_split.py
new file mode 100644
index 00000000..4d42415b
--- /dev/null
+++ b/tests/semantic/data_handling/test_loo_split.py
@@ -0,0 +1,89 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pandas as pd
+
+from rectools.semantic.data_handling.loo_split import loo_split
+
+
+class TestLooSplit:
+    """Checks for the train/val/test frames returned by ``loo_split``."""
+
+    def test_basic_split(self) -> None:
+        """With three rows per user: 1 row in train, 2 in val and 3 in test for each user."""
+        frame = pd.DataFrame(
+            {
+                "user_id": [1, 1, 1, 2, 2, 2],
+                "item_id": [10, 20, 30, 40, 50, 60],
+                "timestamp": [1, 2, 3, 4, 5, 6],
+            }
+        )
+        train, val, test = loo_split(frame)
+        # train omits each user's last two rows, val omits only the last one,
+        # and test keeps every interaction.
+        assert (len(train), len(val), len(test)) == (2, 4, 6)
+
+    def test_users_with_fewer_than_3_interactions_dropped(self) -> None:
+        """User 2 contributes only two rows and must be absent from every split."""
+        frame = pd.DataFrame(
+            {
+                "user_id": [1, 1, 1, 2, 2],
+                "item_id": [10, 20, 30, 40, 50],
+                "timestamp": [1, 2, 3, 4, 5],
+            }
+        )
+        train, val, test = loo_split(frame)
+        for part in (train, val, test):
+            assert 2 not in part["user_id"].values
+
+    def test_sorts_by_timestamp(self) -> None:
+        """Rows supplied out of order are split by ascending timestamp."""
+        frame = pd.DataFrame(
+            {
+                "user_id": [1, 1, 1],
+                "item_id": [30, 10, 20],
+                "timestamp": [3, 1, 2],
+            }
+        )
+        train, val, test = loo_split(frame)
+        # Chronological item order is 10, 20, 30: train holds the first item,
+        # val the first two and test all three.
+        assert train["item_id"].tolist() == [10]
+        assert set(val["item_id"].tolist()) == {10, 20}
+        assert set(test["item_id"].tolist()) == {10, 20, 30}
+
+    def test_no_timestamp_column(self) -> None:
+        """A frame without a ``timestamp`` column still splits into 1/2/3 rows."""
+        frame = pd.DataFrame(
+            {
+                "user_id": [1, 1, 1],
+                "item_id": [10, 20, 30],
+            }
+        )
+        train, val, test = loo_split(frame)
+        # Only the split sizes are checked here, not which rows land where.
+        assert (len(train), len(val), len(test)) == (1, 2, 3)
+
+    def test_all_users_dropped(self) -> None:
+        """When no user reaches three interactions, every split is empty."""
+        frame = pd.DataFrame(
+            {
+                "user_id": [1, 1, 2, 2],
+                "item_id": [10, 20, 30, 40],
+                "timestamp": [1, 2, 3, 4],
+            }
+        )
+        train, val, test = loo_split(frame)
+        # Both users have exactly two interactions.
+        assert (len(train), len(val), len(test)) == (0, 0, 0)
diff --git a/tests/semantic/modules/__init__.py b/tests/semantic/modules/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/semantic/modules/test_mlp.py b/tests/semantic/modules/test_mlp.py
new file mode 100644
index 00000000..b7e9a709
--- /dev/null
+++ b/tests/semantic/modules/test_mlp.py
@@ -0,0 +1,59 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from pytorch_lightning import seed_everything
+
+from rectools.semantic.modules.mlp import MLP
+
+
+class TestMLP:
+    """Shape and input-validation checks for the ``MLP`` module."""
+
+    def setup_method(self) -> None:
+        """Reseed before every test."""
+        seed_everything(42, workers=True)
+
+    def test_output_shape(self) -> None:
+        """One hidden layer maps a (5, 16) batch to (5, 4)."""
+        net = MLP(input_dim=16, hidden_dims=[8], out_dim=4)
+        assert net(torch.randn(5, 16)).shape == (5, 4)
+
+    def test_no_hidden_dims(self) -> None:
+        """An empty hidden-layer list still yields the requested output width."""
+        net = MLP(input_dim=16, hidden_dims=[], out_dim=4)
+        assert net(torch.randn(5, 16)).shape == (5, 4)
+
+    def test_multiple_hidden_dims(self) -> None:
+        """Two hidden layers map a (3, 32) batch to (3, 4)."""
+        net = MLP(input_dim=32, hidden_dims=[16, 8], out_dim=4)
+        assert net(torch.randn(3, 32)).shape == (3, 4)
+
+    def test_with_dropout(self) -> None:
+        """Passing ``dropout=0.5`` leaves the output shape unchanged."""
+        net = MLP(input_dim=16, hidden_dims=[8], out_dim=4, dropout=0.5)
+        assert net(torch.randn(5, 16)).shape == (5, 4)
+
+    def test_with_normalize(self) -> None:
+        """Passing ``normalize=True`` leaves the output shape unchanged."""
+        net = MLP(input_dim=16, hidden_dims=[8], out_dim=4, normalize=True)
+        assert net(torch.randn(5, 16)).shape == (5, 4)
+
+    def test_invalid_input_dim_raises(self) -> None:
+        """A 32-wide input fed to a 16-wide MLP trips its ``AssertionError``."""
+        net = MLP(input_dim=16, hidden_dims=[8], out_dim=4)
+        wrong_width = torch.randn(5, 32)
+        with pytest.raises(AssertionError, match="Invalid input dim"):
+            net(wrong_width)
diff --git a/tests/semantic/modules/test_transformer_blocks.py b/tests/semantic/modules/test_transformer_blocks.py
new file mode 100644
index 00000000..878e7eef
--- /dev/null
+++ b/tests/semantic/modules/test_transformer_blocks.py
@@ -0,0 +1,134 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from pytorch_lightning import seed_everything
+
+from rectools.semantic.modules.transformer_blocks import (
+ T5Attention,
+ T5DecoderLayer,
+ T5EncoderLayer,
+ T5RMSNorm,
+ T5RelativePositionBias,
+ init_feed_forward,
+)
+
+
+class TestT5RMSNorm:
+    """Smoke tests for ``T5RMSNorm`` applied to a (3, 5, 16) tensor."""
+
+    def test_output_shape(self) -> None:
+        layer = T5RMSNorm(d_model=16)
+        assert layer(torch.randn(3, 5, 16)).shape == (3, 5, 16)
+
+    def test_normalized_rms(self) -> None:
+        layer = T5RMSNorm(d_model=16)
+        normalized = layer(torch.randn(3, 5, 16))
+        # NOTE(review): the output is presumably close to unit RMS before learned scaling,
+        # but only finiteness is asserted here.
+        assert torch.isfinite(normalized).all()
+
+
+class TestT5RelativePositionBias:
+    """Shape checks: the bias is expected to be (1, num_heads, query_len, key_len)."""
+
+    def test_output_shape(self) -> None:
+        module = T5RelativePositionBias(num_heads=4, bidirectional=True)
+        assert module(query_len=5, key_len=5, device=torch.device("cpu")).shape == (1, 4, 5, 5)
+
+    def test_causal_output_shape(self) -> None:
+        module = T5RelativePositionBias(num_heads=2, bidirectional=False)
+        assert module(query_len=3, key_len=3, device=torch.device("cpu")).shape == (1, 2, 3, 3)
+
+    def test_asymmetric_shape(self) -> None:
+        module = T5RelativePositionBias(num_heads=2, bidirectional=True)
+        bias = module(query_len=3, key_len=7, device=torch.device("cpu"))
+        assert bias.shape == (1, 2, 3, 7)  # query length first, key length last
+
+
+class TestT5Attention:
+    """Output-shape checks for ``T5Attention``; its second return value is not inspected."""
+
+    def setup_method(self) -> None:
+        seed_everything(42, workers=True)
+
+    def test_self_attention_output_shape(self) -> None:
+        attention = T5Attention(d_model=16, num_heads=2, d_kv=8)
+        hidden = torch.randn(2, 5, 16)
+        attended, _ = attention(hidden, hidden, hidden)
+        assert attended.shape == (2, 5, 16)
+
+    def test_cross_attention_output_shape(self) -> None:
+        attention = T5Attention(d_model=16, num_heads=2, d_kv=8, kv_input_dim=32)
+        queries, memory = torch.randn(2, 3, 16), torch.randn(2, 7, 32)
+        attended, _ = attention(queries, memory, memory)
+        assert attended.shape == (2, 3, 16)  # follows the query, not the 7-step key/value input
+
+    def test_with_position_bias(self) -> None:
+        attention = T5Attention(d_model=16, num_heads=2, d_kv=8)
+        hidden, bias = torch.randn(2, 5, 16), torch.randn(1, 2, 5, 5)
+        attended, _ = attention(hidden, hidden, hidden, position_bias=bias)
+        assert attended.shape == (2, 5, 16)
+
+
+class TestT5EncoderLayer:
+    """Output-shape checks for ``T5EncoderLayer`` with and without a padding mask."""
+
+    def setup_method(self) -> None:
+        seed_everything(42, workers=True)
+
+    def test_output_shape(self) -> None:
+        encoder = T5EncoderLayer(d_model=16, num_heads=2, d_kv=8, dim_feedforward=32)
+        assert encoder(torch.randn(2, 5, 16)).shape == (2, 5, 16)
+
+    def test_with_padding_mask(self) -> None:
+        encoder = T5EncoderLayer(d_model=16, num_heads=2, d_kv=8)
+        # NOTE(review): presumably True marks padded positions — confirm against the layer.
+        padding = torch.tensor([[False, False, False, True, True], [False, False, True, True, True]])
+        encoded = encoder(torch.randn(2, 5, 16), src_key_padding_mask=padding)
+        assert encoded.shape == (2, 5, 16)
+
+
+class TestT5DecoderLayer:
+    """Output-shape check for ``T5DecoderLayer``."""
+
+    def setup_method(self) -> None:
+        seed_everything(42, workers=True)
+
+    def test_output_shape(self) -> None:
+        decoder = T5DecoderLayer(d_model=16, num_heads=2, d_kv=8, dim_feedforward=32)
+        target, memory = torch.randn(2, 3, 16), torch.randn(2, 5, 16)
+        assert decoder(target, memory).shape == (2, 3, 16)  # follows the target length, not the memory length
+
+
+class TestInitFeedForward:
+    """``init_feed_forward`` output keeps the input shape for each tested activation."""
+
+    def test_relu(self) -> None:
+        block = init_feed_forward(16, 4, 0.1, "relu")
+        assert block(torch.randn(2, 5, 16)).shape == (2, 5, 16)
+
+    def test_gelu(self) -> None:
+        block = init_feed_forward(16, 4, 0.1, "gelu")
+        assert block(torch.randn(2, 5, 16)).shape == (2, 5, 16)
+
+    def test_swiglu(self) -> None:
+        block = init_feed_forward(16, 4, 0.1, "swiglu")
+        assert block(torch.randn(2, 5, 16)).shape == (2, 5, 16)
+
+    def test_unsupported_activation_raises(self) -> None:
+        # "tanh" is not among the activations exercised above and must be rejected.
+        with pytest.raises(ValueError, match="Unsupported"):
+            init_feed_forward(16, 4, 0.1, "tanh")
diff --git a/tests/semantic/test_metrics.py b/tests/semantic/test_metrics.py
new file mode 100644
index 00000000..8f38729c
--- /dev/null
+++ b/tests/semantic/test_metrics.py
@@ -0,0 +1,75 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from rectools.semantic.metrics import coverage_k, gini_k
+
+
+class TestCoverageK:
+    """``coverage_k`` cases: distinct items within the first k entries, divided by ``num_items``."""
+
+    def test_all_items_covered(self) -> None:
+        recommendations = [[1, 2, 3], [4, 5]]
+        assert coverage_k(recommendations, k=3, num_items=5) == 1.0
+
+    def test_partial_coverage(self) -> None:
+        # Items 1, 2 and 3 are recommended out of a catalogue of five.
+        recommendations = [[1, 2], [1, 3]]
+        assert coverage_k(recommendations, k=2, num_items=5) == pytest.approx(3 / 5)
+
+    def test_k_limits_items(self) -> None:
+        # Only the first two entries of the list count when k=2.
+        recommendations = [[1, 2, 3, 4, 5]]
+        assert coverage_k(recommendations, k=2, num_items=5) == pytest.approx(2 / 5)
+
+    def test_empty_lists(self) -> None:
+        assert coverage_k([], k=5, num_items=10) == 0.0
+
+    def test_duplicate_items_across_users(self) -> None:
+        # An item repeated across users is counted once.
+        recommendations = [[1, 2], [2, 3], [3, 1]]
+        assert coverage_k(recommendations, k=2, num_items=3) == 1.0
+
+
+class TestGiniK:
+    """``gini_k`` cases: zero for evenly used items, strictly positive for skewed usage."""
+
+    def test_uniform_distribution(self) -> None:
+        # Four different items, one recommendation each.
+        recommendations = [[1], [2], [3], [4]]
+        assert gini_k(recommendations, k=1) == pytest.approx(0.0)
+
+    def test_concentrated_distribution(self) -> None:
+        # A single item recommended to every user is also expected to score 0.
+        recommendations = [[1], [1], [1], [1]]
+        assert gini_k(recommendations, k=1) == pytest.approx(0.0)
+
+    def test_empty_input(self) -> None:
+        # No recommendation lists at all.
+        assert gini_k([], k=5) == 0.0
+
+    def test_k_limits_items(self) -> None:
+        recommendations = [[1, 99], [2, 99]]
+        at_k1, at_k2 = gini_k(recommendations, k=1), gini_k(recommendations, k=2)
+        # k=1 only sees items 1 and 2, once each, so usage is uniform.
+        assert at_k1 == pytest.approx(0.0)
+        # k=2 adds item 99 twice against a single occurrence of items 1 and 2.
+        assert at_k2 > 0.0
+
+    def test_skewed_distribution(self) -> None:
+        # Item 1 is recommended three times and item 2 once.
+        recommendations = [[1], [1], [1], [2]]
+        score = gini_k(recommendations, k=1)
+        assert 0.0 < score < 1.0
diff --git a/tests/semantic/tiger/__init__.py b/tests/semantic/tiger/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/semantic/tiger/test_lightning.py b/tests/semantic/tiger/test_lightning.py
new file mode 100644
index 00000000..535b6987
--- /dev/null
+++ b/tests/semantic/tiger/test_lightning.py
@@ -0,0 +1,106 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+import torch
+from pytorch_lightning import seed_everything
+
+from rectools.semantic.tiger.lightning import TIGERLightning
+from rectools.semantic.tiger.module import TIGERNet
+from rectools.semantic.tokenizer.emb_dataset import EmbDataset
+from rectools.semantic.tokenizer.model import SIDTokenizer
+
+
+@pytest.fixture
+def trained_tokenizer() -> SIDTokenizer:
+    """Return an ``rkmeans`` ``SIDTokenizer`` initialised on 20 random 16-d embeddings."""
+    seed_everything(42, workers=True)
+    vectors = np.random.RandomState(42).randn(20, 16).astype(np.float32)
+    dataset = EmbDataset(list(range(1, 21)), vectors)
+    tokenizer = SIDTokenizer(
+        input_dim=16,
+        codebook_dim=16,
+        codebook_sizes=[4, 4],
+        hidden_dims=None,
+        quantizer="rkmeans",
+        device="cpu",
+    )
+    tokenizer.init_codebooks(dataset)
+    # One pass over the whole dataset before the tokenizer is handed to tests.
+    # NOTE(review): presumably this fills the item <-> SID mappings — confirm.
+    tokenizer(dataset[:])
+    return tokenizer
+
+
+class TestTIGERLightning:  # pylint: disable=redefined-outer-name
+    """Checks one training step and the optimizer configuration of ``TIGERLightning``."""
+
+    def setup_method(self) -> None:
+        seed_everything(42, workers=True)
+
+    @pytest.fixture
+    def lightning_model(self, trained_tokenizer: SIDTokenizer) -> TIGERLightning:
+        """Wrap a one-block ``TIGERNet`` together with the shared tokenizer."""
+        backbone = TIGERNet(
+            codebook_sizes=[4, 4],
+            hidden_units=16,
+            num_blocks=1,
+            num_heads=2,
+            dropout_rate=0.0,
+            max_length=10,
+        )
+        return TIGERLightning(model=backbone, tokenizer=trained_tokenizer, beam_size=4, top_k=3, num_items=20)
+
+    def test_training_step(self, lightning_model: TIGERLightning) -> None:
+        """The training step returns a finite zero-dimensional loss tensor."""
+        batch = {
+            "input_ids": torch.randint(2, 10, (4, 6)),
+            "dec_input": torch.randint(0, 10, (4, 2)),
+            "labels": torch.randint(0, 4, (4, 2)),
+        }
+        loss = lightning_model.training_step(batch, batch_idx=0)
+        assert isinstance(loss, torch.Tensor)
+        assert loss.dim() == 0
+        assert torch.isfinite(loss)
+
+    def test_configure_optimizers_adamw(self, lightning_model: TIGERLightning) -> None:
+        """Without an explicit optimizer or schedule a bare ``AdamW`` is returned."""
+        assert isinstance(lightning_model.configure_optimizers(), torch.optim.AdamW)
+
+    def test_configure_optimizers_with_schedule(self, trained_tokenizer: SIDTokenizer) -> None:
+        """A cosine schedule yields a dict holding both optimizer and scheduler."""
+        backbone = TIGERNet(codebook_sizes=[4, 4], hidden_units=16, num_blocks=1, num_heads=2)
+        module = TIGERLightning(
+            model=backbone,
+            tokenizer=trained_tokenizer,
+            lr_schedule="cosine",
+            warmup_steps=10,
+            max_iters=100,
+        )
+        configured = module.configure_optimizers()
+        assert isinstance(configured, dict)
+        assert "optimizer" in configured
+        assert "lr_scheduler" in configured
+
+    def test_configure_optimizers_unknown_raises(self, trained_tokenizer: SIDTokenizer) -> None:
+        """An unrecognised optimizer name is rejected by ``configure_optimizers``."""
+        backbone = TIGERNet(codebook_sizes=[4, 4], hidden_units=16, num_blocks=1, num_heads=2)
+        module = TIGERLightning(
+            model=backbone,
+            tokenizer=trained_tokenizer,
+            optimizer_name="unknown",  # type: ignore[arg-type]
+        )
+        with pytest.raises(ValueError, match="Unknown optimizer"):
+            module.configure_optimizers()
diff --git a/tests/semantic/tiger/test_loss.py b/tests/semantic/tiger/test_loss.py
new file mode 100644
index 00000000..0369332d
--- /dev/null
+++ b/tests/semantic/tiger/test_loss.py
@@ -0,0 +1,55 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from rectools.semantic.tiger.loss import compute_tiger_loss
+
+
+class TestComputeTigerLoss:
+    """Behavioural checks for ``compute_tiger_loss`` on lists of per-position logits."""
+
+    def test_correct_predictions_low_loss(self) -> None:
+        """Logits of +10 on the label and -10 elsewhere give a loss below 0.1."""
+        rows, vocab, depth = 4, 8, 3
+        labels = torch.randint(0, vocab, (rows, depth))
+        logits = []
+        for level in range(depth):
+            level_logits = torch.full((rows, vocab), -10.0)
+            level_logits[torch.arange(rows), labels[:, level]] = 10.0
+            logits.append(level_logits)
+        loss = compute_tiger_loss(logits, labels)
+        assert loss.item() < 0.1
+
+    def test_random_predictions_higher_loss(self) -> None:
+        """Random normal logits give a loss above 0.1."""
+        rows, vocab, depth = 4, 8, 3
+        labels = torch.randint(0, vocab, (rows, depth))
+        logits = [torch.randn(rows, vocab) for _ in range(depth)]
+        loss = compute_tiger_loss(logits, labels)
+        assert loss.item() > 0.1
+
+    def test_output_is_scalar(self) -> None:
+        """The loss is reduced to a zero-dimensional tensor."""
+        logits = [torch.randn(2, 4), torch.randn(2, 4)]
+        labels = torch.randint(0, 4, (2, 2))
+        loss = compute_tiger_loss(logits, labels)
+        assert loss.dim() == 0
+
+    def test_ignores_minus_100_labels(self) -> None:
+        # The first row carries the label -100; only finiteness of the loss is asserted.
+        logits = [torch.randn(2, 4)]
+        labels = torch.tensor([[-100], [1]])
+        loss = compute_tiger_loss(logits, labels)
+        assert torch.isfinite(loss)
diff --git a/tests/semantic/tiger/test_model.py b/tests/semantic/tiger/test_model.py
new file mode 100644
index 00000000..683497d2
--- /dev/null
+++ b/tests/semantic/tiger/test_model.py
@@ -0,0 +1,245 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+
+import numpy as np
+import pandas as pd
+import pytest
+import torch
+from pytorch_lightning import seed_everything
+
+from rectools.semantic.tiger.model import TIGERModel
+from rectools.semantic.tokenizer.emb_dataset import EmbDataset
+from rectools.semantic.tokenizer.model import SIDTokenizer
+
+
+@pytest.fixture
+def trained_tokenizer() -> SIDTokenizer:
+    """Return an ``rkmeans`` ``SIDTokenizer`` initialised on 50 random 16-d embeddings."""
+    seed_everything(42, workers=True)
+    vectors = np.random.RandomState(42).randn(50, 16).astype(np.float32)
+    dataset = EmbDataset(list(range(1, 51)), vectors)
+    tokenizer = SIDTokenizer(
+        input_dim=16,
+        codebook_dim=16,
+        codebook_sizes=[4, 4],
+        hidden_dims=None,
+        quantizer="rkmeans",
+        device="cpu",
+    )
+    tokenizer.init_codebooks(dataset)
+    # One pass over the whole dataset before the tokenizer is handed to tests.
+    # NOTE(review): presumably this fills the item <-> SID mappings — confirm.
+    tokenizer(dataset[:])
+    return tokenizer
+
+
+@pytest.fixture
+def interactions() -> pd.DataFrame:
+    """Ten users (ids 1-10), each with items 1-5 at timestamps 0-4."""
+    records = [
+        [user, item, step] for user in range(1, 11) for step, item in enumerate(range(1, 6))
+    ]
+    return pd.DataFrame(records, columns=["user_id", "item_id", "timestamp"])
+
+
+class TestTIGERModel:  # pylint: disable=redefined-outer-name
+    """End-to-end checks for ``TIGERModel``: construction, predict, save/load, fit and evaluate."""
+
+    def setup_method(self) -> None:
+        seed_everything(42, workers=True)
+
+    def test_init(self, trained_tokenizer: SIDTokenizer) -> None:
+        """Constructor arguments are exposed as attributes of the model."""
+        model = TIGERModel(
+            tokenizer=trained_tokenizer,
+            hidden_units=16,
+            num_blocks=1,
+            num_heads=2,
+            max_length=10,
+            device="cpu",
+        )
+        assert (model.hidden_units, model.num_blocks) == (16, 1)
+
+    def test_predict_output_format(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None:
+        """``predict`` returns the four result columns and at most ``top_k`` rows per user."""
+        model = TIGERModel(
+            tokenizer=trained_tokenizer,
+            hidden_units=16,
+            num_blocks=1,
+            num_heads=2,
+            max_length=10,
+            device="cpu",
+            beam_size=5,
+            top_k=3,
+        )
+        result = model.predict(interactions, top_k=3)
+        assert isinstance(result, pd.DataFrame)
+        assert set(result.columns) == {"user_id", "item_id", "score", "rank"}
+        for _uid, group in result.groupby("user_id"):
+            assert len(group) <= 3
+
+    def test_save_and_load(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None:
+        """``save`` writes three artifacts and ``load`` restores the architecture settings."""
+        model = TIGERModel(
+            tokenizer=trained_tokenizer,
+            hidden_units=16,
+            num_blocks=1,
+            num_heads=2,
+            max_length=10,
+            device="cpu",
+        )
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save(tmpdir)
+            for artifact in ("tokenizer.pt", "model.pt", "config.json"):
+                assert os.path.exists(os.path.join(tmpdir, artifact))
+
+            loaded = TIGERModel.load(tmpdir, device="cpu")
+            assert loaded.hidden_units == model.hidden_units
+            assert loaded.num_blocks == model.num_blocks
+            assert loaded.num_heads == model.num_heads
+
+    def test_predict_batching(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None:
+        """With ``eval_batch_size=2`` the output keeps its format and per-user ranks start at 1."""
+        model = TIGERModel(
+            tokenizer=trained_tokenizer,
+            hidden_units=16,
+            num_blocks=1,
+            num_heads=2,
+            max_length=10,
+            device="cpu",
+            beam_size=5,
+            top_k=3,
+            eval_batch_size=2,
+        )
+        result = model.predict(interactions, top_k=3)
+        assert isinstance(result, pd.DataFrame)
+        assert set(result.columns) == {"user_id", "item_id", "score", "rank"}
+        # No user id may appear in the output that was not part of the input.
+        assert set(result["user_id"].unique()).issubset(set(interactions["user_id"].unique()))
+        for _uid, group in result.groupby("user_id"):
+            assert list(group["rank"]) == list(range(1, len(group) + 1))
+            assert len(group) <= 3
+
+    def test_fit_runs(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None:
+        """Fit one epoch on users <= 7 (validation: users > 7) and check that predict returns rows."""
+        # The determinism flag is process-global; restore it even if fit/predict fails.
+        previous = torch.are_deterministic_algorithms_enabled()
+        torch.use_deterministic_algorithms(True)
+        try:
+            model = TIGERModel(
+                tokenizer=trained_tokenizer,
+                hidden_units=16,
+                num_blocks=1,
+                num_heads=2,
+                max_length=10,
+                max_epochs=1,
+                batch_size=16,
+                eval_batch_size=16,
+                num_workers=0,
+                beam_size=4,
+                top_k=3,
+                device="cpu",
+            )
+            model.fit(interactions[interactions["user_id"] <= 7], interactions[interactions["user_id"] > 7])
+            assert len(model.predict(interactions, top_k=3)) > 0
+        finally:
+            torch.use_deterministic_algorithms(previous)
+
+    def test_evaluate_uses_unique_item_count_for_coverage(
+        self,
+        trained_tokenizer: SIDTokenizer,
+        monkeypatch: pytest.MonkeyPatch,
+    ) -> None:
+        """``evaluate`` hands the number of distinct items to the Lightning module as ``num_items``."""
+        model = TIGERModel(
+            tokenizer=trained_tokenizer,
+            hidden_units=16,
+            num_blocks=1,
+            num_heads=2,
+            max_length=10,
+            device="cpu",
+        )
+        interactions = pd.DataFrame(
+            [
+                [1, 10, 1],
+                [1, 20, 2],
+                [1, 30, 3],
+                [2, 10, 1],
+                [2, 30, 2],
+                [2, 50, 3],
+            ],
+            columns=["user_id", "item_id", "timestamp"],
+        )
+        captured = {}
+
+        class FakeTrainer:  # records num_items instead of running a real test loop
+            def test(self, lightning_model, dataloaders):
+                captured["num_items"] = lightning_model.num_items
+                return [{"Coverage@10": 0.0}]
+
+        monkeypatch.setattr("rectools.semantic.tiger.model.pl.Trainer", lambda: FakeTrainer())
+        model.evaluate(interactions)
+        # Six rows, but only four distinct item ids: 10, 20, 30 and 50.
+        assert captured["num_items"] == 4
+
+    def test_predict_is_reproducible_with_random_seed(self, trained_tokenizer: SIDTokenizer) -> None:
+        """Two models sharing ``random_seed`` and generated beams produce identical predictions."""
+        # Items 10 and 20 deliberately share the SID (0, 0).
+        collision_sid = (0, 0)
+        trained_tokenizer.id2sid = {
+            10: collision_sid,
+            20: collision_sid,
+            30: (1, 1),
+        }
+        trained_tokenizer.sid2id.clear()
+
+        interactions = pd.DataFrame(
+            [[1, 10, 1], [1, 30, 2]],
+            columns=["user_id", "item_id", "timestamp"],
+        )
+
+        first_model = TIGERModel(
+            tokenizer=trained_tokenizer,
+            hidden_units=16,
+            num_blocks=1,
+            num_heads=2,
+            max_length=10,
+            device="cpu",
+            random_seed=42,
+        )
+        second_model = TIGERModel(
+            tokenizer=trained_tokenizer,
+            hidden_units=16,
+            num_blocks=1,
+            num_heads=2,
+            max_length=10,
+            device="cpu",
+            random_seed=42,
+        )
+
+        generated_codes = torch.tensor([[[0, 0], [0, 0], [1, 1]]], dtype=torch.long)
+        generated_scores = torch.tensor([[0.9, 0.8, 0.7]], dtype=torch.float32)
+
+        def fake_generate(*args, **kwargs):
+            return generated_codes, generated_scores
+
+        first_model.model.generate = fake_generate  # type: ignore[method-assign]
+        second_model.model.generate = fake_generate  # type: ignore[method-assign]
+
+        first_result = first_model.predict(interactions, top_k=3)
+        second_result = second_model.predict(interactions, top_k=3)
+        assert first_result.equals(second_result)
diff --git a/tests/semantic/tiger/test_module.py b/tests/semantic/tiger/test_module.py
new file mode 100644
index 00000000..cc118c57
--- /dev/null
+++ b/tests/semantic/tiger/test_module.py
@@ -0,0 +1,126 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+import torch
+from pytorch_lightning import seed_everything
+
+from rectools.semantic.tiger.module import TIGERNet
+
+
+class TestTIGERNetStandard:
+    """Checks for ``TIGERNet`` on the standard ``nn.Transformer`` path (``d_kv=None``)."""
+
+    def setup_method(self) -> None:
+        seed_everything(42, workers=True)
+
+    @pytest.fixture
+    def model(self) -> TIGERNet:
+        """Small network with two codebooks of size 4 and no ``d_kv`` argument."""
+        return TIGERNet(
+            codebook_sizes=[4, 4],
+            hidden_units=16,
+            num_blocks=1,
+            num_heads=2,
+            max_length=10,
+            dropout_rate=0.0,
+        )
+
+    def test_forward_shapes(self, model: TIGERNet) -> None:
+        """The forward pass yields one (batch, 4) logit tensor per SID position."""
+        batch = 3
+        # Encoder input: 2 items x 2 codewords; decoder input: sid_len = 2 tokens.
+        encoder_tokens = torch.randint(2, 10, (batch, 4))
+        decoder_tokens = torch.randint(0, 10, (batch, 2))
+        per_position = model(encoder_tokens, decoder_tokens)
+        assert len(per_position) == 2
+        assert [tuple(item.shape) for item in per_position] == [(batch, 4)] * 2
+
+    def test_generate_greedy_shape(self, model: TIGERNet) -> None:
+        """Greedy decoding returns codes shaped (batch_size, sid_len)."""
+        model.eval()
+        encoder_tokens = torch.randint(2, 10, (2, 4))
+        assert model.generate_greedy(encoder_tokens).shape == (2, 2)
+
+    def test_generate_beam_shape(self, model: TIGERNet) -> None:
+        """Beam search returns (batch, beam, sid_len) codes and (batch, beam) scores."""
+        model.eval()
+        codes, scores = model.generate(torch.randint(2, 10, (2, 4)), beam_size=3)
+        assert codes.shape == (2, 3, 2)
+        assert scores.shape == (2, 3)
+
+    def test_generate_with_padding_mask(self, model: TIGERNet) -> None:
+        model.eval()
+        pad = TIGERNet.PAD_TOKEN_ID
+        encoder_tokens = torch.randint(2, 10, (2, 6))
+        encoder_tokens[1, 4:] = pad  # pad out the tail of the second sequence
+        codes, _ = model.generate(encoder_tokens, enc_padding_mask=encoder_tokens == pad, beam_size=2)
+        assert codes.shape == (2, 2, 2)
+
+    def test_sid_to_tokens_roundtrip(self, model: TIGERNet) -> None:
+        """``tokens_to_sid`` inverts ``sid_to_tokens``."""
+        sid = torch.tensor([[0, 1], [2, 3]])
+        round_tripped = model.tokens_to_sid(model.sid_to_tokens(sid))
+        assert torch.equal(sid, round_tripped)
+
+    def test_codes_in_valid_range(self, model: TIGERNet) -> None:
+        """Each greedy code lies in ``[0, codebook_size)`` of its position."""
+        model.eval()
+        codes = model.generate_greedy(torch.randint(2, 10, (3, 4)))
+        for position in range(model.sid_len):
+            column = codes[:, position]
+            assert (column >= 0).all() and (column < model.codebook_sizes[position]).all()
+
+
+class TestTIGERNetT5:
+    """Checks for ``TIGERNet`` built with T5-style layers (``d_kv`` set)."""
+
+    def setup_method(self) -> None:
+        seed_everything(42, workers=True)
+
+    @pytest.fixture
+    def model(self) -> TIGERNet:
+        """Small network with two codebooks of size 4 and ``d_kv=8``."""
+        return TIGERNet(
+            codebook_sizes=[4, 4],
+            hidden_units=16,
+            num_blocks=1,
+            num_heads=2,
+            max_length=10,
+            dropout_rate=0.0,
+            d_kv=8,
+        )
+
+    def test_forward_shapes(self, model: TIGERNet) -> None:
+        encoder_tokens, decoder_tokens = torch.randint(2, 10, (2, 4)), torch.randint(0, 10, (2, 2))
+        per_position = model(encoder_tokens, decoder_tokens)
+        assert len(per_position) == 2
+        assert per_position[0].shape == (2, 4)
+
+    def test_generate_greedy(self, model: TIGERNet) -> None:
+        """Greedy decoding returns codes shaped (batch_size, sid_len)."""
+        model.eval()
+        encoder_tokens = torch.randint(2, 10, (2, 4))
+        assert model.generate_greedy(encoder_tokens).shape == (2, 2)
+
+    def test_generate_beam(self, model: TIGERNet) -> None:
+        """Beam search returns (batch, beam, sid_len) codes and (batch, beam) scores."""
+        model.eval()
+        codes, scores = model.generate(torch.randint(2, 10, (2, 4)), beam_size=3)
+        assert codes.shape == (2, 3, 2)
+        assert scores.shape == (2, 3)
+
+    def test_no_positional_embeddings(self, model: TIGERNet) -> None:
+        """With ``d_kv`` set, both positional embedding attributes are ``None``."""
+        assert model.enc_pos_emb is None and model.dec_pos_emb is None
diff --git a/tests/semantic/tokenizer/__init__.py b/tests/semantic/tokenizer/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/semantic/tokenizer/test_emb_dataset.py b/tests/semantic/tokenizer/test_emb_dataset.py
new file mode 100644
index 00000000..5e6a196b
--- /dev/null
+++ b/tests/semantic/tokenizer/test_emb_dataset.py
@@ -0,0 +1,50 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytest
+
+from rectools.semantic.tokenizer.emb_dataset import EmbDataset
+
+
+class TestEmbDataset:
+    """Sizing, indexing and validation behaviour of ``EmbDataset``."""
+
+    @pytest.fixture
+    def dataset(self) -> EmbDataset:
+        vectors = np.random.RandomState(42).randn(4, 8).astype(np.float32)
+        return EmbDataset([10, 20, 30, 40], vectors)
+
+    def test_len(self, dataset: EmbDataset) -> None:
+        assert len(dataset) == 4
+
+    def test_dim(self, dataset: EmbDataset) -> None:
+        assert dataset.dim == 8
+
+    def test_getitem_int(self, dataset: EmbDataset) -> None:
+        first = dataset[0]
+        assert first["item_id"] == 10
+        assert first["embed"].shape == (8,)
+
+    def test_getitem_slice(self, dataset: EmbDataset) -> None:
+        head = dataset[0:2]
+        assert head["item_id"] == [10, 20]
+        assert head["embed"].shape == (2, 8)
+
+    def test_getitem_full_slice(self, dataset: EmbDataset) -> None:
+        assert len(dataset[:]["item_id"]) == 4
+
+    def test_mismatched_lengths_raises(self) -> None:
+        with pytest.raises(ValueError, match="item_ids length"):
+            EmbDataset([1, 2, 3], np.zeros((4, 8)))  # three ids against four embedding rows
diff --git a/tests/semantic/tokenizer/test_loss.py b/tests/semantic/tokenizer/test_loss.py
new file mode 100644
index 00000000..eb506d9f
--- /dev/null
+++ b/tests/semantic/tokenizer/test_loss.py
@@ -0,0 +1,47 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from rectools.semantic.tokenizer.loss import CodebookLoss
+
+
+class TestCodebookLoss:  # Tests CodebookLoss on small random (4, 8) tensors.
+ def test_zero_when_equal(self) -> None:  # Identical inputs must give a loss of exactly 0.0.
+ loss_fn = CodebookLoss()
+ x = torch.randn(4, 8)
+ result = loss_fn(x, x.clone())  # clone: same values, separate tensor object
+ assert result.item() == 0.0
+
+ def test_positive_when_different(self) -> None:  # Independent random inputs must give a loss > 0.
+ loss_fn = CodebookLoss()
+ y_true = torch.randn(4, 8)
+ y_pred = torch.randn(4, 8)
+ result = loss_fn(y_true, y_pred)
+ assert result.item() > 0.0
+
+ def test_custom_coefficients(self) -> None:  # Larger mu/beta must give a larger loss on the same inputs.
+ loss_default = CodebookLoss(mu=1.0, beta=0.25)
+ loss_custom = CodebookLoss(mu=2.0, beta=0.5)
+ y_true = torch.randn(4, 8)
+ y_pred = torch.randn(4, 8)
+ r_default = loss_default(y_true, y_pred)
+ r_custom = loss_custom(y_true, y_pred)
+ # Both coefficients are doubled; the test only requires a strictly larger loss
+ assert r_custom.item() > r_default.item()
+
+ def test_output_is_scalar(self) -> None:  # The loss is a 0-dimensional tensor.
+ loss_fn = CodebookLoss()
+ result = loss_fn(torch.randn(4, 8), torch.randn(4, 8))
+ assert result.dim() == 0
diff --git a/tests/semantic/tokenizer/test_model.py b/tests/semantic/tokenizer/test_model.py
new file mode 100644
index 00000000..adb27897
--- /dev/null
+++ b/tests/semantic/tokenizer/test_model.py
@@ -0,0 +1,229 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import tempfile
+
+import numpy as np
+import pytest
+import torch
+from pytorch_lightning import seed_everything
+
+from rectools.semantic.tokenizer.emb_dataset import EmbDataset
+from rectools.semantic.tokenizer.model import SIDTokenizer
+
+
+class TestSIDTokenizerRKmeans:  # Tests SIDTokenizer with quantizer="rkmeans" on 100 random 16-d items.
+ def setup_method(self) -> None:  # pytest hook run before each test: reseed RNGs.
+ seed_everything(42, workers=True)
+
+ @pytest.fixture
+ def dataset(self) -> EmbDataset:  # Item ids 0..99 with seeded float32 embeddings of width 16.
+ rng = np.random.RandomState(42)
+ item_ids = list(range(100))
+ embeddings = rng.randn(100, 16).astype(np.float32)
+ return EmbDataset(item_ids, embeddings)
+
+ @pytest.fixture
+ def tokenizer(self, dataset: EmbDataset) -> SIDTokenizer:  # Tokenizer that has already seen all 100 items.
+ tok = SIDTokenizer(
+ codebook_dim=16,
+ hidden_dims=None,
+ codebook_sizes=[8, 8],  # two levels, so SIDs checked below have length 2
+ input_dim=16,
+ device="cpu",
+ quantizer="rkmeans",
+ )
+ tok.init_codebooks(dataset)
+ # Run a forward pass to populate id2sid
+ batch = dataset[:]
+ tok(batch)
+ return tok
+
+ def test_init(self) -> None:  # Constructor stores the quantizer name and codebook sizes.
+ tok = SIDTokenizer(
+ codebook_dim=16,
+ hidden_dims=None,
+ codebook_sizes=[8, 8],
+ input_dim=16,
+ quantizer="rkmeans",
+ )
+ assert tok.quantizer_name == "rkmeans"
+ assert tok.codebook_sizes == [8, 8]
+
+ def test_forward_returns_loss(self, dataset: EmbDataset) -> None:  # Forward returns a non-negative scalar tensor.
+ tok = SIDTokenizer(
+ codebook_dim=16,
+ hidden_dims=None,
+ codebook_sizes=[8, 8],
+ input_dim=16,
+ device="cpu",
+ quantizer="rkmeans",
+ )
+ tok.init_codebooks(dataset)
+ batch = dataset[0:10]
+ loss = tok(batch)
+ assert isinstance(loss, torch.Tensor)
+ assert loss.dim() == 0
+ assert loss.item() >= 0
+
+ def test_forward_populates_id2sid(self, dataset: EmbDataset) -> None:  # Forward records one SID per batch item.
+ tok = SIDTokenizer(
+ codebook_dim=16,
+ hidden_dims=None,
+ codebook_sizes=[8, 8],
+ input_dim=16,
+ device="cpu",
+ quantizer="rkmeans",
+ )
+ tok.init_codebooks(dataset)
+ batch = dataset[0:10]
+ tok(batch)
+ assert len(tok.id2sid) == 10  # exactly the ten items in the batch
+
+ def test_tokenize_single(self, tokenizer: SIDTokenizer) -> None:  # One id maps to one 2-tuple.
+ sid = tokenizer.tokenize(0)
+ assert isinstance(sid, tuple)
+ assert len(sid) == 2
+
+ def test_tokenize_multiple(self, tokenizer: SIDTokenizer) -> None:  # A list of ids maps to a list of 2-tuples.
+ sids = tokenizer.tokenize([0, 1, 2])
+ assert isinstance(sids, list)
+ assert len(sids) == 3
+ for sid in sids:
+ assert isinstance(sid, tuple)
+ assert len(sid) == 2
+
+ def test_tokenize_unknown_item_raises(self, tokenizer: SIDTokenizer) -> None:  # Id 9999 was never seen.
+ with pytest.raises(ValueError, match="out of vocabulary"):
+ tokenizer.tokenize(9999)
+
+ def test_decode_single(self, tokenizer: SIDTokenizer) -> None:  # Decoding a known SID yields some item id.
+ sid = tokenizer.tokenize(5)
+ item_id = tokenizer.decode(sid)
+ # The decoded item should exist in the vocabulary
+ assert item_id is not None  # not compared to 5: the test does not require an exact round trip
+
+ def test_decode_multiple(self, tokenizer: SIDTokenizer) -> None:  # A list of SIDs decodes to an equally long list.
+ sids = tokenizer.tokenize([0, 1, 2])
+ decoded = tokenizer.decode(sids)
+ assert isinstance(decoded, list)
+ assert len(decoded) == 3
+
+ def test_decode_unknown_sid(self, tokenizer: SIDTokenizer) -> None:  # Unknown SID decodes to None by default.
+ unknown_sid = (999, 999)  # codes outside the configured codebook sizes [8, 8]
+ result = tokenizer.decode(unknown_sid)
+ assert result is None
+
+ def test_decode_with_default_value(self, tokenizer: SIDTokenizer) -> None:  # default_value replaces None.
+ unknown_sid = (999, 999)
+ result = tokenizer.decode(unknown_sid, default_value=-1)
+ assert result == -1
+
+ def test_decode_collision_is_reproducible_with_rng(self, tokenizer: SIDTokenizer) -> None:  # Same seed, same picks.
+ tokenizer.id2sid = {
+ 10: (1, 1),
+ 20: (1, 1),  # ids 10 and 20 deliberately share one SID
+ 30: (2, 2),
+ }
+ tokenizer.sid2id.clear()  # presumably forces the reverse mapping to be rebuilt from id2sid; TODO confirm
+
+ first_rng = np.random.RandomState(42)
+ second_rng = np.random.RandomState(42)
+
+ first = tokenizer.decode([(1, 1), (1, 1), (1, 1)], rng=first_rng)
+ second = tokenizer.decode([(1, 1), (1, 1), (1, 1)], rng=second_rng)
+
+ assert first == second
+
+ def test_len(self, tokenizer: SIDTokenizer) -> None:  # len() is positive and at most the 100 fitted items.
+ assert len(tokenizer) > 0
+ assert len(tokenizer) <= 100
+
+ def test_extend(self, tokenizer: SIDTokenizer) -> None:  # extend() registers 20 previously unseen ids.
+ rng = np.random.RandomState(99)
+ new_ids = list(range(100, 120))
+ new_embeddings = rng.randn(20, 16).astype(np.float32)
+ new_dataset = EmbDataset(new_ids, new_embeddings)
+ tokenizer.extend(new_dataset)
+ # New items should be in id2sid
+ for item_id in new_ids:
+ assert item_id in tokenizer.id2sid
+
+ def test_extend_skips_existing(self, tokenizer: SIDTokenizer) -> None:  # extend() must not reassign a known id.
+ old_sid = tokenizer.id2sid[0]
+ rng = np.random.RandomState(99)
+ # Include existing item 0
+ ids = [0, 200]
+ embs = rng.randn(2, 16).astype(np.float32)  # item 0 gets a new embedding here, unlike the fixture's
+ ds = EmbDataset(ids, embs)
+ tokenizer.extend(ds)
+ # Item 0 should still have the same SID
+ assert tokenizer.id2sid[0] == old_sid
+ assert 200 in tokenizer.id2sid
+
+ def test_save_and_load(self, tokenizer: SIDTokenizer) -> None:  # save()/load() round trip keeps config and SIDs.
+ with tempfile.TemporaryDirectory() as tmpdir:
+ path = os.path.join(tmpdir, "tokenizer.pt")
+ tokenizer.save(path)
+ loaded = SIDTokenizer.load(path, map_location="cpu")
+
+ assert loaded.codebook_sizes == tokenizer.codebook_sizes
+ assert loaded.input_dim == tokenizer.input_dim
+ assert loaded.codebook_dim == tokenizer.codebook_dim
+ assert set(loaded.id2sid.keys()) == set(tokenizer.id2sid.keys())
+
+ # Check that tokenization produces same results
+ sid_orig = tokenizer.tokenize(5)
+ sid_loaded = loaded.tokenize(5)
+ assert sid_orig == sid_loaded
+
+
+class TestSIDTokenizerRQVAE:  # Tests SIDTokenizer with quantizer="rqvae".
+ def setup_method(self) -> None:  # pytest hook run before each test: reseed RNGs.
+ seed_everything(42, workers=True)
+
+ @pytest.fixture
+ def dataset(self) -> EmbDataset:  # Item ids 0..99 with seeded float32 embeddings of width 16.
+ rng = np.random.RandomState(42)
+ item_ids = list(range(100))
+ embeddings = rng.randn(100, 16).astype(np.float32)
+ return EmbDataset(item_ids, embeddings)
+
+ def test_init_rqvae(self) -> None:  # Constructor stores the quantizer name.
+ tok = SIDTokenizer(
+ codebook_dim=8,
+ hidden_dims=[8],
+ codebook_sizes=[4, 4],
+ input_dim=16,
+ quantizer="rqvae",
+ )
+ assert tok.quantizer_name == "rqvae"
+
+ def test_forward_rqvae(self, dataset: EmbDataset) -> None:  # Forward gives a finite loss and one SID per item.
+ tok = SIDTokenizer(
+ codebook_dim=8,
+ hidden_dims=[8],
+ codebook_sizes=[4, 4],
+ input_dim=16,
+ device="cpu",
+ quantizer="rqvae",
+ adapter_proj_dim=8,
+ )
+ tok.init_codebooks(dataset)
+ batch = dataset[0:10]
+ loss = tok(batch)
+ assert isinstance(loss, torch.Tensor)
+ assert torch.isfinite(loss)  # guards against NaN/inf
+ assert len(tok.id2sid) == 10  # exactly the ten items in the batch
diff --git a/tests/semantic/tokenizer/test_rkmeans.py b/tests/semantic/tokenizer/test_rkmeans.py
new file mode 100644
index 00000000..5fc11358
--- /dev/null
+++ b/tests/semantic/tokenizer/test_rkmeans.py
@@ -0,0 +1,79 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from pytorch_lightning import seed_everything
+
+from rectools.semantic.tokenizer.rkmeans import Codebook, RKmeans
+
+
+class TestCodebook:  # Tests a single Codebook layer: output shapes, code range, centroid init.
+ def setup_method(self) -> None:  # pytest hook run before each test: reseed RNGs.
+ seed_everything(42, workers=True)
+
+ def test_output_shapes(self) -> None:  # Forward returns (codes, quantized, residuals).
+ cb = Codebook(code_dim=8, n_codes=16)
+ inputs = torch.randn(4, 8)
+ codes, quantized, residuals = cb(inputs)
+ assert codes.shape == (4,)  # one code per input row
+ assert quantized.shape == (4, 8)  # same shape as the input
+ assert residuals.shape == (4, 8)
+
+ def test_codes_in_range(self) -> None:  # Every code lies in [0, n_codes).
+ cb = Codebook(code_dim=8, n_codes=16)
+ inputs = torch.randn(10, 8)
+ codes, _, _ = cb(inputs)
+ assert (codes >= 0).all()
+ assert (codes < 16).all()
+
+ def test_init_from_centroids(self) -> None:  # Centroids are copied into the weights and the flag is set.
+ cb = Codebook(code_dim=4, n_codes=3)
+ centroids = torch.randn(3, 4)  # shape (n_codes, code_dim)
+ cb.init_from_centroids(centroids)
+ assert cb.kmeans_initted_
+ assert torch.allclose(cb.code_embs.weight.data, centroids)
+
+
+class TestRKmeans:  # Tests the RKmeans quantizer: forward outputs and codebook initialisation.
+ def setup_method(self) -> None:  # pytest hook run before each test: reseed RNGs.
+ seed_everything(42, workers=True)
+
+ def test_forward_shapes(self) -> None:  # Forward returns (sem_ids, loss): one id per row, scalar loss.
+ model = RKmeans(input_dim=16, codebook_sizes=[8, 8])
+ inputs = torch.randn(10, 16)
+ sem_ids, loss = model(inputs)
+ assert len(sem_ids) == 10
+ assert all(len(sid) == 2 for sid in sem_ids)  # one code per codebook level
+ assert loss.dim() == 0
+
+ def test_loss_is_positive(self) -> None:  # Random inputs give a strictly positive loss.
+ model = RKmeans(input_dim=16, codebook_sizes=[8, 8])
+ inputs = torch.randn(10, 16)
+ _, loss = model(inputs)
+ assert loss.item() > 0.0
+
+ def test_init_codebooks(self) -> None:  # init_codebooks() marks every layer as initialised.
+ model = RKmeans(input_dim=16, codebook_sizes=[4, 4])
+ data = torch.randn(100, 16)
+ model.init_codebooks(data)
+ for layer in model.codebooks:
+ assert layer.kmeans_initted_
+
+ def test_sem_ids_are_tuples(self) -> None:  # With three codebooks each id is a 3-tuple.
+ model = RKmeans(input_dim=8, codebook_sizes=[4, 4, 4])
+ inputs = torch.randn(5, 8)
+ sem_ids, _ = model(inputs)
+ for sid in sem_ids:
+ assert isinstance(sid, tuple)
+ assert len(sid) == 3
diff --git a/tests/semantic/tokenizer/test_rqvae.py b/tests/semantic/tokenizer/test_rqvae.py
new file mode 100644
index 00000000..a1b4223a
--- /dev/null
+++ b/tests/semantic/tokenizer/test_rqvae.py
@@ -0,0 +1,106 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from pytorch_lightning import seed_everything
+
+from rectools.semantic.tokenizer.rqvae import RQVAE, WhiteningAdapter
+
+
+class TestWhiteningAdapter:  # Tests WhiteningAdapter output shapes and its state after init_from_embeds().
+ def setup_method(self) -> None:  # pytest hook run before each test: reseed RNGs.
+ seed_everything(42, workers=True)
+
+ def test_output_shape(self) -> None:  # "identity" adapter maps (5, 16) inputs to (5, 8).
+ adapter = WhiteningAdapter(adapter_type="identity", emb_dim=16, proj_dim=8)
+ inputs = torch.randn(5, 16)
+ adapter.init_from_embeds(inputs)
+ output = adapter(inputs)
+ assert output.shape == (5, 8)  # last axis is proj_dim, not emb_dim
+
+ def test_ffn_adapter(self) -> None:  # "ffn" adapter maps (10, 16) inputs to (10, 8).
+ adapter = WhiteningAdapter(adapter_type="ffn", emb_dim=16, proj_dim=8, hidden_units=4)
+ inputs = torch.randn(10, 16)
+ adapter.init_from_embeds(inputs)
+ output = adapter(inputs)
+ assert output.shape == (10, 8)
+
+ def test_freeze_adapter(self) -> None:  # After init_from_embeds(), weight and bias take no gradients.
+ adapter = WhiteningAdapter(emb_dim=16, proj_dim=8)  # adapter_type left at its default
+ inputs = torch.randn(10, 16)
+ adapter.init_from_embeds(inputs)
+ assert not adapter.weight.requires_grad
+ assert not adapter.bias.requires_grad
+
+
+class TestRQVAE:  # Tests the RQVAE quantizer: forward outputs, loss finiteness, codebook init.
+ def setup_method(self) -> None:  # pytest hook run before each test: reseed RNGs.
+ seed_everything(42, workers=True)
+
+ def test_forward_shapes(self) -> None:  # Forward returns (sem_ids, loss): one id per row, scalar loss.
+ model = RQVAE(
+ input_dim=16,
+ codebook_dim=8,
+ hidden_dims=[8],
+ codebook_sizes=[4, 4],
+ adapter_proj_dim=8,
+ )
+ data = torch.randn(20, 16)
+ model.init_codebooks(data)
+ sem_ids, loss = model(data)
+ assert len(sem_ids) == 20
+ assert all(len(sid) == 2 for sid in sem_ids)  # one code per codebook level
+ assert loss.dim() == 0
+
+ def test_loss_is_finite(self) -> None:  # The loss is neither NaN nor infinite.
+ model = RQVAE(
+ input_dim=16,
+ codebook_dim=8,
+ hidden_dims=[8],
+ codebook_sizes=[4, 4],
+ adapter_proj_dim=8,
+ )
+ data = torch.randn(20, 16)
+ model.init_codebooks(data)
+ _, loss = model(data)
+ assert torch.isfinite(loss)
+
+ def test_init_codebooks_sets_flag(self) -> None:  # init_codebooks() marks every layer as initialised.
+ model = RQVAE(
+ input_dim=16,
+ codebook_dim=8,
+ hidden_dims=[8],
+ codebook_sizes=[4, 4],
+ adapter_proj_dim=8,
+ )
+ data = torch.randn(50, 16)
+ model.init_codebooks(data)
+ for layer in model.codebooks:
+ assert layer.kmeans_initted_
+
+ def test_sem_ids_are_tuples_of_ints(self) -> None:  # With three codebooks each id is a 3-tuple of Python ints.
+ model = RQVAE(
+ input_dim=16,
+ codebook_dim=8,
+ hidden_dims=[8],
+ codebook_sizes=[4, 4, 4],
+ adapter_proj_dim=8,
+ )
+ data = torch.randn(20, 16)
+ model.init_codebooks(data)
+ sem_ids, _ = model(data)
+ for sid in sem_ids:
+ assert isinstance(sid, tuple)
+ assert all(isinstance(c, int) for c in sid)  # plain int, so tensor elements would fail here
+ assert len(sid) == 3
diff --git a/tests/semantic/tokenizer/test_train.py b/tests/semantic/tokenizer/test_train.py
new file mode 100644
index 00000000..cfdee28f
--- /dev/null
+++ b/tests/semantic/tokenizer/test_train.py
@@ -0,0 +1,72 @@
+# Copyright 2025 MTS (Mobile Telesystems)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+
+import numpy as np
+from pytorch_lightning import seed_everything
+
+from rectools.semantic.tokenizer.model import SIDTokenizer
+
+
+class TestSIDTokenizerFit:  # Smoke tests for SIDTokenizer.fit() with both quantizers.
+ def setup_method(self) -> None:  # pytest hook run before each test: reseed RNGs.
+ seed_everything(42, workers=True)
+
+ def test_fit_rkmeans(self) -> None:  # A short rkmeans fit leaves id2sid non-empty.
+ rng = np.random.RandomState(42)
+ item_ids = list(range(100))
+ embeddings = rng.randn(100, 16).astype(np.float32)
+
+ with tempfile.TemporaryDirectory() as tmpdir:  # directory passed to fit() as save_dir; removed on exit
+ tok = SIDTokenizer(
+ input_dim=16,
+ codebook_sizes=[4, 4],
+ quantizer="rkmeans",
+ device="cpu",
+ )
+ tok.fit(
+ item_ids=item_ids,
+ embeddings=embeddings,
+ max_epochs=3,
+ patience=2,
+ batch_size=50,
+ save_dir=tmpdir,
+ )
+ assert len(tok.id2sid) > 0
+
+ def test_fit_rqvae(self) -> None:  # A short rqvae fit leaves id2sid non-empty.
+ rng = np.random.RandomState(42)
+ item_ids = list(range(100))
+ embeddings = rng.randn(100, 16).astype(np.float32)
+
+ with tempfile.TemporaryDirectory() as tmpdir:  # directory passed to fit() as save_dir; removed on exit
+ tok = SIDTokenizer(
+ input_dim=16,
+ codebook_sizes=[4, 4],
+ codebook_dim=8,
+ hidden_dims=[8],
+ quantizer="rqvae",
+ adapter_proj_dim=8,
+ device="cpu",
+ )
+ tok.fit(
+ item_ids=item_ids,
+ embeddings=embeddings,
+ max_epochs=3,
+ patience=2,
+ batch_size=50,
+ save_dir=tmpdir,
+ )
+ assert len(tok.id2sid) > 0