diff --git a/.pylintrc b/.pylintrc index 4e65a5b3..d3f932dc 100644 --- a/.pylintrc +++ b/.pylintrc @@ -12,7 +12,7 @@ ignore=CVS # Add files or directories matching the regex patterns to the blacklist. # The regex matches against base names, not paths. -ignore-patterns= +ignore-patterns=test_transformer_blocks # Python code to execute, usually for sys.path manipulation # such as pygtk.require(). diff --git a/CHANGELOG.md b/CHANGELOG.md index 070e44c9..3683c370 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added +- `rectools.semantic` package with TIGER generative recommender components, including model, Lightning integration, loss, tokenizer modules, semantic metrics, and data handling utilities +- Semantic building blocks for neural recommender workflows, including transformer blocks, MLP modules, residual k-means, and RQ-VAE tokenizer implementations +- TIGER tutorial notebook, SASRec vs TIGER benchmark script, and semantic test coverage across data handling, modules, tokenizer, metrics, and TIGER components + ## [Unreleased] diff --git a/benchmark/compare_sasrec_tiger.py b/benchmark/compare_sasrec_tiger.py new file mode 100644 index 00000000..0045503e --- /dev/null +++ b/benchmark/compare_sasrec_tiger.py @@ -0,0 +1,810 @@ +"""Compare RecTools SASRecModel and semantic TIGER on ML-20M. + +Both models are trained on MovieLens 20M data and the benchmark writes a +Markdown report to ``benchmark/comparison_sasrec_vs_tiger.md``. + +Usage: +```bash +python3 -m benchmark.compare_sasrec_tiger +``` + +The script downloads ML-20M automatically if it is not present locally. +Movie metadata is taken from ``movies.csv`` and embedded with +``Qwen/Qwen3-Embedding-0.6B`` for the TIGER tokenizer. +""" + +# pylint: disable=too-many-locals,too-many-statements,import-outside-toplevel +# pylint: disable=import-error,unsubscriptable-object,protected-access,cyclic-import + +import argparse +import gc +import shutil +import time +import zipfile +from datetime import datetime +from pathlib import Path +from typing import Optional +from urllib.request import urlretrieve + +import numpy as np +import pandas as pd +import pytorch_lightning as pl +import torch +from pytorch_lightning import Trainer +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from sentence_transformers import SentenceTransformer +from tqdm.auto import trange + +from rectools import Columns +from rectools.dataset import Dataset +from rectools.metrics import MRR, NDCG, CatalogCoverage, HitRate +from rectools.models import SASRecModel +from rectools.models.nn.item_net import IdEmbeddingsItemNet +from rectools.models.nn.transformers.utils import leave_one_out_mask +from rectools.semantic.data_handling import k_core, loo_split +from rectools.semantic.metrics import gini_k +from rectools.semantic.tiger import TIGERModel +from rectools.semantic.tokenizer import SIDTokenizer + +BENCHMARK_DIR = Path(__file__).resolve().parent +DEFAULT_WORKDIR = BENCHMARK_DIR / "data" / "ml-20m" +DEFAULT_REPORT_PATH = BENCHMARK_DIR / "comparison_sasrec_vs_tiger.md" +ML20M_URL = "https://files.grouplens.org/datasets/movielens/ml-20m.zip" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Benchmark RecTools SASRec against semantic TIGER on MovieLens 20M." + ) + parser.add_argument( + "--workdir", + type=Path, + default=DEFAULT_WORKDIR, + help="Directory for downloaded data and cached artifacts.", + ) + parser.add_argument( + "--report-path", + type=Path, + default=DEFAULT_REPORT_PATH, + help="Path to the Markdown report produced by the benchmark.", + ) + parser.add_argument( + "--device", + type=str, + default=None, + help="Torch device override. Defaults to cuda -> mps -> cpu.", + ) + parser.add_argument( + "--top-k-main", + type=int, + default=10, + help="K for NDCG, MRR, coverage, and gini.", + ) + parser.add_argument( + "--max-length", type=int, default=200, help="Maximum context length." + ) + parser.add_argument("--top-k-hit", type=int, default=10, help="K for hit rate.") + parser.add_argument("--seed", type=int, default=42, help="Random seed.") + parser.add_argument( + "--limit-users", + type=int, + default=None, + help="Optional cap on users after preprocessing.", + ) + parser.add_argument( + "--min-rating", + type=float, + default=-1.0, + help="Optional minimum rating filter. Use -1 to keep all ratings.", + ) + parser.add_argument( + "--min-item-interactions", + type=int, + default=5, + help="Optional minimum number of interactions per item.", + ) + parser.add_argument( + "--min-user-interactions", + type=int, + default=2, + help="Optional minimum number of interactions per user.", + ) + parser.add_argument( + "--embedding-model", + type=str, + default="Qwen/Qwen3-Embedding-0.6B", + help="SentenceTransformer model used to embed movie metadata.", + ) + parser.add_argument( + "--tokenizer-max-epochs", + type=int, + default=1000, + help="Tokenizer training epochs.", + ) + parser.add_argument( + "--tokenizer-batch-size", + type=int, + default=256, + help="Batch size for metadata embedding and tokenizer training.", + ) + parser.add_argument( + "--tokenizer-patience", + type=int, + default=100, + help="Tokenizer early stopping patience.", + ) + parser.add_argument( + "--sasrec-epochs", type=int, default=10, help="Maximum SASRec epochs." + ) + parser.add_argument( + "--sasrec-patience", + type=int, + default=10, + help="SASRec early stopping patience.", + ) + parser.add_argument( + "--sasrec-batch-size", type=int, default=256, help="SASRec batch size." + ) + parser.add_argument( + "--sasrec-num-workers", type=int, default=4, help="SASRec dataloader workers." + ) + parser.add_argument( + "--tiger-epochs", type=int, default=10, help="Maximum TIGER epochs." + ) + parser.add_argument( + "--tiger-patience", type=int, default=10, help="TIGER early stopping patience." + ) + parser.add_argument( + "--tiger-batch-size", type=int, default=256, help="TIGER train batch size." + ) + parser.add_argument( + "--tiger-eval-batch-size", type=int, default=256, help="TIGER eval batch size." + ) + parser.add_argument( + "--tiger-num-workers", type=int, default=4, help="TIGER dataloader workers." + ) + return parser.parse_args() + + +def resolve_device(device: Optional[str]) -> str: + if device is not None: + return device + if torch.cuda.is_available(): + return "cuda" + if ( + getattr(torch.backends, "mps", None) is not None + and torch.backends.mps.is_available() + ): + return "mps" + return "cpu" + + +def describe_device(device: str) -> str: + if device.startswith("cuda") and torch.cuda.is_available(): + gpu_index = torch.device(device).index or 0 + return f"{device} ({torch.cuda.get_device_name(gpu_index)})" + if device == "mps": + return "mps (Apple Metal)" + return "cpu" + + +def maybe_download(url: str, dst: Path) -> None: + if dst.exists(): + return + dst.parent.mkdir(parents=True, exist_ok=True) + print(f"Downloading {url} -> {dst}") + urlretrieve(url, dst) # nosec B310 + + +def download_ml20m(workdir: Path) -> tuple[Path, Path]: + ratings_path = workdir / "ratings.csv" + movies_path = workdir / "movies.csv" + if ratings_path.exists() and movies_path.exists(): + return ratings_path, movies_path + + archive_path = workdir / "ml-20m.zip" + maybe_download(ML20M_URL, archive_path) + + print(f"Extracting {archive_path}") + with zipfile.ZipFile(archive_path) as archive: + for member in archive.namelist(): + basename = Path(member).name + if basename not in { + "ratings.csv", + "movies.csv", + "README.txt", + "links.csv", + "tags.csv", + }: + continue + target = workdir / basename + if target.exists(): + continue + with archive.open(member) as src, open(target, "wb") as dst_stream: + shutil.copyfileobj(src, dst_stream) + + return ratings_path, movies_path + + +def load_interactions(ratings_path: Path, args: argparse.Namespace) -> pd.DataFrame: + interactions = pd.read_csv(ratings_path) + interactions.columns = ["user_id", "item_id", "rating", "timestamp"] + + if args.min_rating > 0: + interactions = interactions[interactions["rating"] >= args.min_rating] + + use_k_core = args.min_item_interactions > 0 or args.min_user_interactions > 0 + if use_k_core: + print( + f"Performing K-Core filtering (user>={args.min_user_interactions}, item>={args.min_item_interactions})..." + ) + interactions = k_core( + interactions, + user_min_interactions=max(args.min_user_interactions, 0), + item_min_interactions=max(args.min_item_interactions, 0), + ) + + if args.limit_users is not None: + kept_users = interactions["user_id"].drop_duplicates().iloc[: args.limit_users] + interactions = interactions[ + interactions["user_id"].isin(kept_users) + ].reset_index(drop=True) + if use_k_core: + interactions = k_core( + interactions, + user_min_interactions=max(args.min_user_interactions, 0), + item_min_interactions=max(args.min_item_interactions, 0), + ) + + interactions = interactions.sort_values( + ["user_id", "timestamp", "item_id"] + ).reset_index(drop=True) + interactions["timestamp"] = pd.to_datetime(interactions["timestamp"], unit="s") + + return interactions + + +def save_preprocessed_files( + workdir: Path, + interactions: pd.DataFrame, + meta: pd.DataFrame, + train_df: pd.DataFrame, + val_df: pd.DataFrame, + test_df: pd.DataFrame, + test_targets: pd.DataFrame, +) -> dict[str, Path]: + preprocessed_dir = workdir / "preprocessed" + preprocessed_dir.mkdir(parents=True, exist_ok=True) + + paths = { + "interactions": preprocessed_dir / "interactions.csv", + "item_metadata": preprocessed_dir / "item_metadata.csv", + "train": preprocessed_dir / "train.csv", + "val": preprocessed_dir / "val.csv", + "test": preprocessed_dir / "test.csv", + "test_targets": preprocessed_dir / "test_targets.csv", + } + interactions.to_csv(paths["interactions"], index=False) + meta.to_csv(paths["item_metadata"], index=False) + train_df.to_csv(paths["train"], index=False) + val_df.to_csv(paths["val"], index=False) + test_df.to_csv(paths["test"], index=False) + test_targets.to_csv(paths["test_targets"], index=False) + + return paths + + +def build_meta(movies_path: Path, interactions: pd.DataFrame) -> pd.DataFrame: + movies = pd.read_csv(movies_path) + movies = movies.rename( + columns={"movieId": "item_id", "title": "title", "genres": "genres"} + ) + item_ids = set(interactions["item_id"].unique()) + + meta = movies[movies["item_id"].isin(item_ids)].copy() + meta["title"] = meta["title"].fillna("Unknown").astype(str) + meta["genres"] = meta["genres"].fillna("Unknown").astype(str) + meta["text"] = meta["title"] + " Genres: " + meta["genres"] + meta = meta[["item_id", "text"]].drop_duplicates("item_id") + + missing_items = sorted(item_ids - set(meta["item_id"])) + if missing_items: + missing_meta = pd.DataFrame( + { + "item_id": missing_items, + "text": ["Unknown title. Genres: Unknown"] * len(missing_items), + } + ) + meta = pd.concat([meta, missing_meta], ignore_index=True) + + return meta.sort_values("item_id").reset_index(drop=True) + + +def map_ids( + interactions: pd.DataFrame, meta: pd.DataFrame +) -> tuple[pd.DataFrame, pd.DataFrame]: + user_codes = pd.Index(interactions["user_id"].unique()) + item_codes = pd.Index(interactions["item_id"].unique()) + + user_map = pd.Series(np.arange(len(user_codes), dtype=np.int64), index=user_codes) + item_map = pd.Series(np.arange(len(item_codes), dtype=np.int64), index=item_codes) + + mapped_interactions = interactions.copy() + mapped_interactions["user_id"] = ( + mapped_interactions["user_id"].map(user_map).astype(np.int64) + ) + mapped_interactions["item_id"] = ( + mapped_interactions["item_id"].map(item_map).astype(np.int64) + ) + + mapped_meta = meta[meta["item_id"].isin(item_map.index)].copy() + mapped_meta["item_id"] = mapped_meta["item_id"].map(item_map).astype(np.int64) + mapped_meta = mapped_meta.sort_values("item_id").reset_index(drop=True) + + return mapped_interactions, mapped_meta + + +def encode_texts( + texts: list[str], model_name: str, device: str, batch_size: int +) -> np.ndarray: + embeds = [] + model = SentenceTransformer(model_name, device=device).eval() + with torch.no_grad(): + for start in trange(0, len(texts), batch_size, desc="Embedding movie metadata"): + end = min(start + batch_size, len(texts)) + embeds.append(model.encode(texts[start:end], show_progress_bar=False)) + return np.vstack(embeds) + + +def make_sasrec_dataset(val_df: pd.DataFrame) -> Dataset: + sasrec_df = val_df.rename(columns={"timestamp": Columns.Datetime}).copy() + sasrec_df[Columns.Weight] = 1.0 + return Dataset.construct( + interactions_df=sasrec_df[ + [Columns.User, Columns.Item, Columns.Weight, Columns.Datetime] + ] + ) + + +def get_last_item_targets(test_df: pd.DataFrame) -> pd.DataFrame: + targets = ( + test_df.sort_values(["user_id", "timestamp"]) + .groupby("user_id", sort=False) + .tail(1) + .copy() + ) + targets["weight"] = 1.0 + return targets[["user_id", "item_id", "weight"]] + + +def reco_to_topk_lists(reco: pd.DataFrame, k: int) -> list[list[int]]: + ordered = reco[reco[Columns.Rank] <= k].sort_values([Columns.User, Columns.Rank]) + grouped = ordered.groupby(Columns.User, sort=False)[Columns.Item].agg(list) + return grouped.tolist() + + +def evaluate_recommendations( + reco: pd.DataFrame, + targets: pd.DataFrame, + catalog_size: int, + top_k_main: int, + top_k_hit: int, +) -> dict[str, float]: + return { + f"NDCG@{top_k_main}": NDCG(k=top_k_main).calc(reco, targets), + f"HR@{top_k_hit}": HitRate(k=top_k_hit).calc(reco, targets), + f"MRR@{top_k_main}": MRR(k=top_k_main).calc(reco, targets), + f"Coverage@{top_k_main}": CatalogCoverage(k=top_k_main, normalize=True).calc( + reco, catalog=list(range(catalog_size)) + ), + f"Gini@{top_k_main}": gini_k(reco_to_topk_lists(reco, top_k_main), top_k_main), + } + + +def train_sasrec( + dataset: Dataset, + device: str, + max_epochs: int, + patience: int, + batch_size: int, + num_workers: int, + max_length: int, + min_user_interactions: int, +) -> SASRecModel: + accelerator = "gpu" if device.startswith("cuda") else device + trainer_kwargs = { + "accelerator": accelerator, + "devices": 1, + "min_epochs": 1, + "max_epochs": max_epochs, + "callbacks": [ + EarlyStopping( + monitor=SASRecModel.val_loss_name, + patience=patience, + mode="min", + ) + ], + "deterministic": True, + } + if device.startswith("cuda"): + trainer_kwargs["precision"] = "bf16-mixed" + + trainer = Trainer(**trainer_kwargs) + + model = SASRecModel( + n_factors=200, + n_blocks=1, + n_heads=2, + dropout_rate=0.1, + train_min_user_interactions=min_user_interactions, + session_max_len=max_length, + item_net_block_types=(IdEmbeddingsItemNet,), + get_val_mask_func=leave_one_out_mask, + batch_size=batch_size, + dataloader_num_workers=num_workers, + verbose=1, + deterministic=True, + recommend_torch_device=device, + ) + model._trainer = trainer + model.fit(dataset) + return model + + +def cleanup() -> None: + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def benchmark_sasrec( + val_df: pd.DataFrame, + test_targets: pd.DataFrame, + device: str, + args: argparse.Namespace, +) -> tuple[dict[str, float], float, float]: + print("Converting validation history to RecTools dataset for SASRec") + dataset = make_sasrec_dataset(val_df) + + print("Fitting RecTools SASRec") + fit_start = time.perf_counter() + model = train_sasrec( + dataset=dataset, + device=device, + max_epochs=args.sasrec_epochs, + patience=args.sasrec_patience, + batch_size=args.sasrec_batch_size, + num_workers=args.sasrec_num_workers, + max_length=args.max_length, + min_user_interactions=args.min_user_interactions, + ) + train_seconds = time.perf_counter() - fit_start + + users = test_targets["user_id"].to_numpy() + inference_start = time.perf_counter() + reco = model.recommend( + users=users, + dataset=dataset, + k=max(args.top_k_main, args.top_k_hit), + filter_viewed=False, + ) + inference_seconds = time.perf_counter() - inference_start + + metrics = evaluate_recommendations( + reco=reco, + targets=test_targets, + catalog_size=val_df["item_id"].nunique(), + top_k_main=args.top_k_main, + top_k_hit=args.top_k_hit, + ) + return metrics, train_seconds, inference_seconds + + +def benchmark_tiger( + train_df: pd.DataFrame, + val_df: pd.DataFrame, + test_targets: pd.DataFrame, + meta: pd.DataFrame, + device: str, + args: argparse.Namespace, +) -> tuple[dict[str, float], float, float]: + meta = meta.sort_values("item_id").drop_duplicates("item_id") + embeddings = encode_texts( + meta["text"].tolist(), + model_name=args.embedding_model, + device=device, + batch_size=args.tokenizer_batch_size, + ) + tokenizer = SIDTokenizer( + input_dim=embeddings.shape[1], + codebook_sizes=[256, 256, 256], + codebook_dim=32, + hidden_dims=[768, 256, 128, 64], + quantizer="rqvae", + device=device, + ) + tokenizer.fit( + item_ids=meta["item_id"].tolist(), + embeddings=embeddings, + max_epochs=args.tokenizer_max_epochs, + patience=args.tokenizer_patience, + batch_size=args.tokenizer_batch_size, + save_dir=str(args.workdir / "tokenizer_checkpoints"), + ) + + tiger = TIGERModel( + tokenizer=tokenizer, + hidden_units=128, + num_blocks=4, + num_heads=6, + dropout_rate=0.1, + max_length=args.max_length, + lr=1e-3, + lr_schedule="cosine", + warmup_steps=0, + max_epochs=args.tiger_epochs, + patience=args.tiger_patience, + batch_size=args.tiger_batch_size, + eval_batch_size=args.tiger_eval_batch_size, + num_workers=args.tiger_num_workers, + beam_size=max(20, args.top_k_hit), + top_k=args.top_k_main, + d_kv=64, + ff_dim=1024, + device=device, + random_seed=args.seed, + ) + + print("Fitting TIGER") + fit_start = time.perf_counter() + tiger.fit(train_df=train_df, val_df=val_df) + train_seconds = time.perf_counter() - fit_start + + inference_start = time.perf_counter() + reco = tiger.predict(val_df, top_k=max(args.top_k_main, args.top_k_hit)) + inference_seconds = time.perf_counter() - inference_start + + metrics = evaluate_recommendations( + reco=reco, + targets=test_targets, + catalog_size=val_df["item_id"].nunique(), + top_k_main=args.top_k_main, + top_k_hit=args.top_k_hit, + ) + return metrics, train_seconds, inference_seconds + + +def collect_dataset_info( + interactions: pd.DataFrame, + train_df: pd.DataFrame, + val_df: pd.DataFrame, + test_df: pd.DataFrame, + meta: pd.DataFrame, +) -> dict[str, object]: + timestamps = interactions["timestamp"] + return { + "n_interactions": len(interactions), + "n_users": interactions["user_id"].nunique(), + "n_items": interactions["item_id"].nunique(), + "n_train": len(train_df), + "n_val": len(val_df), + "n_test": len(test_df), + "n_meta_items": len(meta), + "min_timestamp": timestamps.min(), + "max_timestamp": timestamps.max(), + "avg_interactions_per_user": len(interactions) + / interactions["user_id"].nunique(), + "avg_interactions_per_item": len(interactions) + / interactions["item_id"].nunique(), + } + + +def write_report( + report_path: Path, + args: argparse.Namespace, + device: str, + data_info: dict[str, object], + preprocessed_paths: dict[str, Path], + results: dict[str, dict[str, float]], + started_at: datetime, + finished_at: datetime, +) -> None: + report_path.parent.mkdir(parents=True, exist_ok=True) + report_date = finished_at.strftime("%Y-%m-%d %H:%M:%S") + duration_seconds = (finished_at - started_at).total_seconds() + dataset_label = "MovieLens 20M" + + metric_names = [ + f"NDCG@{args.top_k_main}", + f"HR@{args.top_k_hit}", + f"MRR@{args.top_k_main}", + f"Coverage@{args.top_k_main}", + f"Gini@{args.top_k_main}", + ] + + lines = [ + "# SASRec vs TIGER Comparison", + "", + f"**Date:** {report_date} ", + f"**Duration:** {duration_seconds:.1f}s ", + f"**Device:** {describe_device(device)} ", + f"**Dataset:** {dataset_label} ", + f"**Report path:** `{report_path}`", + "", + "## Run Summary", + "", + "| Field | Value |", + "|---|---|", + f"| Data directory | `{args.workdir}` |", + f"| Ratings source | `{args.workdir / 'ratings.csv'}` |", + f"| Movies source | `{args.workdir / 'movies.csv'}` |", + f"| Embedding model | `{args.embedding_model}` |", + f"| Torch version | `{torch.__version__}` |", + f"| PyTorch Lightning version | `{pl.__version__}` |", + f"| Seed | {args.seed} |", + f"| Max sequence length | {args.max_length} |", + f"| top_k_main | {args.top_k_main} |", + f"| top_k_hit | {args.top_k_hit} |", + "", + "## Dataset Statistics", + "", + "| Statistic | Value |", + "|---|---:|", + f"| Interactions | {data_info['n_interactions']:,} |", + f"| Users | {data_info['n_users']:,} |", + f"| Items | {data_info['n_items']:,} |", + f"| Metadata rows | {data_info['n_meta_items']:,} |", + f"| Train interactions | {data_info['n_train']:,} |", + f"| Validation interactions | {data_info['n_val']:,} |", + f"| Test interactions | {data_info['n_test']:,} |", + f"| Avg interactions per user | {data_info['avg_interactions_per_user']:.2f} |", + f"| Avg interactions per item | {data_info['avg_interactions_per_item']:.2f} |", + f"| First timestamp | {data_info['min_timestamp']} |", + f"| Last timestamp | {data_info['max_timestamp']} |", + "", + "## Filtering and Data Preparation", + "", + "| Parameter | Value |", + "|---|---|", + f"| Minimum rating | {args.min_rating} |", + f"| Minimum item interactions | {args.min_item_interactions} |", + f"| Minimum user interactions | {args.min_user_interactions} |", + f"| User cap | {args.limit_users if args.limit_users is not None else 'None'} |", + "| Interaction count filter | Iterative k-core |", + "| Split strategy | Leave-one-out per user |", + "| TIGER metadata text | `title + genres` from `movies.csv` |", + "", + "## Preprocessed Files", + "", + "| File | Path |", + "|---|---|", + f"| Interactions | `{preprocessed_paths['interactions']}` |", + f"| Item metadata | `{preprocessed_paths['item_metadata']}` |", + f"| Train split | `{preprocessed_paths['train']}` |", + f"| Validation split | `{preprocessed_paths['val']}` |", + f"| Test split | `{preprocessed_paths['test']}` |", + f"| Test targets | `{preprocessed_paths['test_targets']}` |", + "", + "## Model Configuration", + "", + "| Parameter | SASRec | TIGER |", + "|---|---|---|", + f"| Epochs | {args.sasrec_epochs} | {args.tiger_epochs} |", + f"| Patience | {args.sasrec_patience} | {args.tiger_patience} |", + f"| Batch size | {args.sasrec_batch_size} | {args.tiger_batch_size} |", + f"| Eval batch size | - | {args.tiger_eval_batch_size} |", + f"| Num workers | {args.sasrec_num_workers} | {args.tiger_num_workers} |", + f"| Max length | {args.max_length} | {args.max_length} |", + "| Architecture notes | RecTools SASRec defaults from this script | TIGER + SIDTokenizer (RQ-VAE) |", + "", + "## Results", + "", + "| Model | Train time, s | Inference time, s | " + + " | ".join(metric_names) + + " |", + "|---|---:|---:|" + "---:" * len(metric_names) + "|", + ] + + for model_name, model_metrics in results.items(): + row = [ + model_name, + f"{model_metrics['train_seconds']:.3f}", + f"{model_metrics['inference_seconds']:.3f}", + ] + row.extend(f"{model_metrics[metric_name]:.6f}" for metric_name in metric_names) + lines.append("| " + " | ".join(row) + " |") + + lines.extend( + [ + "", + "## Notes", + "", + "- The script downloads ML-20M automatically if the data files are missing.", + "- TIGER uses stochastic SID collision resolution; this benchmark fixes it with the run seed.", + "- Coverage is computed against the total number of unique mapped items in the evaluation catalog.", + ] + ) + + report_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + print(f"Report written to {report_path}") + + +def main() -> None: + args = parse_args() + started_at = datetime.now() + pl.seed_everything(args.seed, workers=True) + + device = resolve_device(args.device) + args.workdir.mkdir(parents=True, exist_ok=True) + + ratings_path, movies_path = download_ml20m(args.workdir) + + interactions = load_interactions(ratings_path, args) + meta = build_meta(movies_path, interactions) + interactions, meta = map_ids(interactions, meta) + + train_df, val_df, test_df = loo_split(interactions, timestamp_col="timestamp") + test_targets = get_last_item_targets(test_df) + preprocessed_paths = save_preprocessed_files( + workdir=args.workdir, + interactions=interactions, + meta=meta, + train_df=train_df, + val_df=val_df, + test_df=test_df, + test_targets=test_targets, + ) + data_info = collect_dataset_info(interactions, train_df, val_df, test_df, meta) + + print( + "Prepared dataset:", + f"{data_info['n_interactions']:,} interactions,", + f"{data_info['n_users']:,} users,", + f"{data_info['n_items']:,} items", + ) + + sasrec_metrics, sasrec_train_s, sasrec_infer_s = benchmark_sasrec( + val_df=val_df, + test_targets=test_targets, + device=device, + args=args, + ) + cleanup() + + tiger_metrics, tiger_train_s, tiger_infer_s = benchmark_tiger( + train_df=train_df, + val_df=val_df, + test_targets=test_targets, + meta=meta, + device=device, + args=args, + ) + + results = { + "SASRec": { + "train_seconds": sasrec_train_s, + "inference_seconds": sasrec_infer_s, + **sasrec_metrics, + }, + "TIGER": { + "train_seconds": tiger_train_s, + "inference_seconds": tiger_infer_s, + **tiger_metrics, + }, + } + + finished_at = datetime.now() + write_report( + report_path=args.report_path, + args=args, + device=device, + data_info=data_info, + preprocessed_paths=preprocessed_paths, + results=results, + started_at=started_at, + finished_at=finished_at, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/tutorials/tiger_tutorial.ipynb b/examples/tutorials/tiger_tutorial.ipynb new file mode 100644 index 00000000..f4bc37d3 --- /dev/null +++ b/examples/tutorials/tiger_tutorial.ipynb @@ -0,0 +1,1711 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "intro", + "metadata": {}, + "source": [ + "# TIGER: Generative Retrieval with Semantic IDs\n", + "\n", + "This tutorial walks through the TIGER generative retrieval model for sequential recommendation\n", + "using the [MovieLens 20M](https://grouplens.org/datasets/movielens/20m/) dataset.\n", + "\n", + "TIGER represents items as **Semantic IDs** -- tuples of discrete codes produced by RQ-VAE.\n", + "A Transformer encoder-decoder is then trained to predict the next item's Semantic ID given a user's interaction history.\n", + "\n", + "### Pipeline overview\n", + "\n", + "1. Download and prepare ML-20M interaction data\n", + "2. Apply k-core filtering\n", + "3. Train a `SIDTokenizer` (with random embeddings for this demo)\n", + "4. Create a `TIGERModel` with the tokenizer\n", + "5. Split data with `loo_split` (leave-one-out)\n", + "6. Train with `fit(train_df, val_df)`\n", + "7. Evaluate with `evaluate(test_df)`\n", + "8. Generate recommendations with `predict(interactions)`\n", + "9. Save / load the trained model" + ] + }, + { + "cell_type": "markdown", + "id": "setup-header", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "imports", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import warnings\n", + "\n", + "from rectools.semantic.data_handling import k_core, loo_split\n", + "from rectools.semantic.tokenizer import SIDTokenizer\n", + "from rectools.semantic.tiger import TIGERModel\n", + "\n", + "warnings.simplefilter(\"ignore\")" + ] + }, + { + "cell_type": "markdown", + "id": "data-header", + "metadata": {}, + "source": [ + "## 1. Download and prepare ML-20M\n", + "\n", + "We download the MovieLens 20M dataset and prepare it as a DataFrame with\n", + "`user_id`, `item_id`, and `timestamp` columns." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "download-data", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ml-20m.zip]\n", + " End-of-central-directory signature not found. Either this file is not\n", + " a zipfile, or it constitutes one disk of a multi-part archive. In the\n", + " latter case the central directory and zipfile comment will be found on\n", + " the last disk(s) of this archive.\n", + "unzip: cannot find zipfile directory in one of ml-20m.zip or\n", + " ml-20m.zip.zip, and cannot find ml-20m.zip.ZIP, period.\n", + "CPU times: user 8.84 ms, sys: 41.8 ms, total: 50.7 ms\n", + "Wall time: 794 ms\n" + ] + } + ], + "source": [ + "%%time\n", + "!wget -q https://files.grouplens.org/datasets/movielens/ml-20m.zip -O ml-20m.zip\n", + "!unzip -o -q ml-20m.zip\n", + "!rm ml-20m.zip" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "load-data", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Interactions: 9,995,410\n", + "Users: 138,287, Items: 20,720\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idtimestamp
0110791094785665
1129591094785698
2139961094785727
311511094785734
4113741094785746
\n", + "
" + ], + "text/plain": [ + " user_id item_id timestamp\n", + "0 1 1079 1094785665\n", + "1 1 2959 1094785698\n", + "2 1 3996 1094785727\n", + "3 1 151 1094785734\n", + "4 1 1374 1094785746" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ratings = pd.read_csv(\"ml-20m/ratings.csv\")\n", + "ratings.columns = [\"user_id\", \"item_id\", \"rating\", \"timestamp\"]\n", + "\n", + "# Keep only positive interactions (rating >= 4)\n", + "interactions = ratings[ratings[\"rating\"] >= 4][[\"user_id\", \"item_id\", \"timestamp\"]].copy()\n", + "interactions = interactions.sort_values([\"user_id\", \"timestamp\"]).reset_index(drop=True)\n", + "\n", + "print(f\"Interactions: {len(interactions):,}\")\n", + "print(f\"Users: {interactions['user_id'].nunique():,}, Items: {interactions['item_id'].nunique():,}\")\n", + "interactions.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "load-movies", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
item_idtitlegenres
01Toy Story (1995)Adventure|Animation|Children|Comedy|Fantasy
12Jumanji (1995)Adventure|Children|Fantasy
23Grumpier Old Men (1995)Comedy|Romance
34Waiting to Exhale (1995)Comedy|Drama|Romance
45Father of the Bride Part II (1995)Comedy
\n", + "
" + ], + "text/plain": [ + " item_id title \\\n", + "0 1 Toy Story (1995) \n", + "1 2 Jumanji (1995) \n", + "2 3 Grumpier Old Men (1995) \n", + "3 4 Waiting to Exhale (1995) \n", + "4 5 Father of the Bride Part II (1995) \n", + "\n", + " genres \n", + "0 Adventure|Animation|Children|Comedy|Fantasy \n", + "1 Adventure|Children|Fantasy \n", + "2 Comedy|Romance \n", + "3 Comedy|Drama|Romance \n", + "4 Comedy " + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "movies = pd.read_csv(\"ml-20m/movies.csv\")\n", + "movies.columns = [\"item_id\", \"title\", \"genres\"]\n", + "movies.head()" + ] + }, + { + "cell_type": "markdown", + "id": "kcore-header", + "metadata": {}, + "source": [ + "## 2. K-core filtering\n", + "\n", + "`k_core()` iteratively removes users and items with fewer than the specified minimum interactions.\n", + "This ensures that every user and item has enough signal for sequential modeling." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "kcore", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "After k-core: 9,977,451 interactions\n", + "Users: 136,674, Items: 13,680\n" + ] + } + ], + "source": [ + "interactions = k_core(\n", + " interactions,\n", + " user_min_interactions=5,\n", + " item_min_interactions=5,\n", + ")\n", + "\n", + "print(f\"After k-core: {len(interactions):,} interactions\")\n", + "print(f\"Users: {interactions['user_id'].nunique():,}, Items: {interactions['item_id'].nunique():,}\")" + ] + }, + { + "cell_type": "markdown", + "id": "tokenizer-header", + "metadata": {}, + "source": [ + "## 3. Train a tokenizer\n", + "\n", + "The `SIDTokenizer` maps each item ID to a **Semantic ID** -- a tuple of discrete codes.\n", + "It uses RQ-VAE (or RK-means) to quantize item embeddings into hierarchical codebooks.\n", + "\n", + "### Embedding items with sentence-transformers\n", + "\n", + "Here we embed item metadata (title, genres, description) using\n", + "a sentence-transformers model and pass the resulting embeddings to `SIDTokenizer.fit()`:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4f6e3d21", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from sentence_transformers import SentenceTransformer\n", + "from tqdm.auto import trange\n", + "\n", + "def encode_texts(texts: list, model_name: str, device: str = \"cpu\", batch_size: int = 256) -> np.ndarray:\n", + " \"\"\"Encode texts into embeddings using a SentenceTransformer model.\n", + "\n", + " Parameters\n", + " ----------\n", + " texts : list of str\n", + " Texts to encode.\n", + " model_name : str\n", + " Name or path of a sentence-transformers model.\n", + " device : str\n", + " Device to run the model on.\n", + " batch_size : int\n", + " Batch size for encoding.\n", + "\n", + " Returns\n", + " -------\n", + " np.ndarray\n", + " Embedding matrix of shape ``(len(texts), embed_dim)``.\n", + " \"\"\"\n", + " embeds = []\n", + " model = SentenceTransformer(model_name).to(device).eval()\n", + " with torch.no_grad():\n", + " for start in trange(0, len(texts), batch_size, desc=\"Embedding the metadata\"):\n", + " end = min(start + batch_size, len(texts))\n", + " embeds.append(model.encode(texts[start:end]))\n", + "\n", + " return np.vstack(embeds)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "2ea6e31b-0a9c-4fa3-8464-eb35df9dd87b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0d0b8b2fd59648838f70c6f0a4609108", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Embedding the metadata: 0%| | 0/54 [00:00\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
SIDitem_idtext
0(235, 229, 202)421Black Beauty (1994) Adventure Children Drama
1(235, 229, 202)986Fly Away Home (1996) Adventure Children
2(10, 64, 101)502Next Karate Kid, The (1994) Action Children Ro...
3(10, 64, 101)2422Karate Kid, Part III, The (1989) Action Advent...
4(15, 132, 12)913Maltese Falcon, The (1941) Film-Noir Mystery
5(15, 132, 12)8228Maltese Falcon, The (a.k.a. Dangerous Female) ...
6(208, 107, 255)1007Apple Dumpling Gang, The (1975) Children Comed...
7(208, 107, 255)2016Apple Dumpling Gang Rides Again, The (1979) Ch...
8(99, 84, 111)1644I Know What You Did Last Summer (1997) Horror ...
9(99, 84, 111)2338I Still Know What You Did Last Summer (1998) H...
10(183, 144, 185)2124Addams Family, The (1991) Children Comedy Fantasy
11(183, 144, 185)27075Addams Family Reunion (1998) Children Comedy F...
12(93, 110, 17)3068Verdict, The (1982) Drama Mystery
13(93, 110, 17)82121Verdict, The (1946) Crime Drama Film-Noir Myst...
14(180, 31, 0)3149Diamonds (1999) Mystery
15(180, 31, 0)3219Pacific Heights (1990) Mystery Thriller
16(180, 60, 223)3660Puppet Master (1989) Horror Sci-Fi Thriller
17(180, 60, 223)3661Puppet Master II (1991) Horror Sci-Fi Thriller
18(179, 99, 81)4544Short Circuit 2 (1988) Comedy Sci-Fi
19(179, 99, 81)4545Short Circuit (1986) Comedy Sci-Fi
20(31, 154, 77)5575Alias Betty (Betty Fisher et autres histoires)...
21(31, 154, 77)35082Lila Says (Lila dit ça) (2004) Crime Drama Rom...
22(11, 204, 216)5736Faces of Death 3 (1985) Documentary Horror
23(11, 204, 216)5739Faces of Death 6 (1996) Documentary Horror
24(60, 149, 155)6159All the Real Girls (2003) Drama Romance
25(60, 149, 155)67365After Sex (2007) Drama Romance
26(2, 109, 34)6407Walk, Don't Run (1966) Comedy Romance
27(2, 109, 34)30742Some Came Running (1958) Drama Romance
28(14, 164, 68)6428Two Mules for Sister Sara (1970) Comedy War We...
29(14, 164, 68)7897Ballad of Cable Hogue, The (1970) Comedy Western
30(230, 2, 206)8730To End All Wars (2001) Action Drama War
31(230, 2, 206)102445Star Trek Into Darkness (2013) Action Adventur...
32(209, 218, 225)8925Spinning Boris (2003) Comedy Drama
33(209, 218, 225)55190Love and Other Disasters (2006) Comedy Romance
34(199, 49, 227)8938Tarnation (2003) Documentary
35(199, 49, 227)110387Unknown Known, The (2013) Documentary
36(8, 220, 31)26007Unknown Soldier, The (Tuntematon sotilas) (195...
37(8, 220, 31)26560Unknown Soldier, The (Tuntematon sotilas) (198...
38(223, 206, 99)26270Lone Wolf and Cub: Baby Cart at the River Styx...
39(223, 206, 99)62803Lone Wolf and Cub: Baby Cart in Peril (Kozure ...
40(197, 115, 131)43908London (2005) Drama
41(197, 115, 131)70948London (1994) Documentary
42(157, 159, 6)59315Iron Man (2008) Action Adventure Sci-Fi
43(157, 159, 6)77561Iron Man 2 (2010) Action Adventure Sci-Fi Thri...
44(116, 33, 244)61026Red Cliff (Chi bi) (2008) Action Adventure Dra...
45(116, 33, 244)68486Red Cliff Part II (Chi Bi Xia: Jue Zhan Tian X...
46(196, 29, 161)62390Autism: The Musical (2007) Documentary
47(196, 29, 161)85190Public Speaking (2010) Documentary
48(109, 81, 219)65682Underworld: Rise of the Lycans (2009) Action F...
49(109, 81, 219)91974Underworld: Awakening (2012) Action Fantasy Ho...
50(170, 55, 67)77330Red Riding: 1980 (2009) Crime Drama Mystery
51(170, 55, 67)77359Red Riding: 1983 (2009) Crime Drama Mystery
52(118, 203, 203)86355Atlas Shrugged: Part 1 (2011) Drama Mystery Sc...
53(118, 203, 203)97324Atlas Shrugged: Part II (2012) Drama Mystery S...
54(250, 128, 116)88129Drive (2011) Crime Drama Film-Noir Thriller
55(250, 128, 116)101137Dead Man Down (2013) Action Crime Drama Romanc...
56(34, 147, 143)106204Pieta (2013) Drama
57(34, 147, 143)111249Belle (2013) Drama
\n", + "" + ], + "text/plain": [ + " SID item_id \\\n", + "0 (235, 229, 202) 421 \n", + "1 (235, 229, 202) 986 \n", + "2 (10, 64, 101) 502 \n", + "3 (10, 64, 101) 2422 \n", + "4 (15, 132, 12) 913 \n", + "5 (15, 132, 12) 8228 \n", + "6 (208, 107, 255) 1007 \n", + "7 (208, 107, 255) 2016 \n", + "8 (99, 84, 111) 1644 \n", + "9 (99, 84, 111) 2338 \n", + "10 (183, 144, 185) 2124 \n", + "11 (183, 144, 185) 27075 \n", + "12 (93, 110, 17) 3068 \n", + "13 (93, 110, 17) 82121 \n", + "14 (180, 31, 0) 3149 \n", + "15 (180, 31, 0) 3219 \n", + "16 (180, 60, 223) 3660 \n", + "17 (180, 60, 223) 3661 \n", + "18 (179, 99, 81) 4544 \n", + "19 (179, 99, 81) 4545 \n", + "20 (31, 154, 77) 5575 \n", + "21 (31, 154, 77) 35082 \n", + "22 (11, 204, 216) 5736 \n", + "23 (11, 204, 216) 5739 \n", + "24 (60, 149, 155) 6159 \n", + "25 (60, 149, 155) 67365 \n", + "26 (2, 109, 34) 6407 \n", + "27 (2, 109, 34) 30742 \n", + "28 (14, 164, 68) 6428 \n", + "29 (14, 164, 68) 7897 \n", + "30 (230, 2, 206) 8730 \n", + "31 (230, 2, 206) 102445 \n", + "32 (209, 218, 225) 8925 \n", + "33 (209, 218, 225) 55190 \n", + "34 (199, 49, 227) 8938 \n", + "35 (199, 49, 227) 110387 \n", + "36 (8, 220, 31) 26007 \n", + "37 (8, 220, 31) 26560 \n", + "38 (223, 206, 99) 26270 \n", + "39 (223, 206, 99) 62803 \n", + "40 (197, 115, 131) 43908 \n", + "41 (197, 115, 131) 70948 \n", + "42 (157, 159, 6) 59315 \n", + "43 (157, 159, 6) 77561 \n", + "44 (116, 33, 244) 61026 \n", + "45 (116, 33, 244) 68486 \n", + "46 (196, 29, 161) 62390 \n", + "47 (196, 29, 161) 85190 \n", + "48 (109, 81, 219) 65682 \n", + "49 (109, 81, 219) 91974 \n", + "50 (170, 55, 67) 77330 \n", + "51 (170, 55, 67) 77359 \n", + "52 (118, 203, 203) 86355 \n", + "53 (118, 203, 203) 97324 \n", + "54 (250, 128, 116) 88129 \n", + "55 (250, 128, 116) 101137 \n", + "56 (34, 147, 143) 106204 \n", + "57 (34, 147, 143) 111249 \n", + "\n", + " text \n", + "0 Black Beauty (1994) Adventure Children Drama \n", + "1 Fly Away Home (1996) Adventure Children \n", + "2 Next Karate Kid, The (1994) Action Children Ro... \n", + "3 Karate Kid, Part III, The (1989) Action Advent... \n", + "4 Maltese Falcon, The (1941) Film-Noir Mystery \n", + "5 Maltese Falcon, The (a.k.a. Dangerous Female) ... \n", + "6 Apple Dumpling Gang, The (1975) Children Comed... \n", + "7 Apple Dumpling Gang Rides Again, The (1979) Ch... \n", + "8 I Know What You Did Last Summer (1997) Horror ... \n", + "9 I Still Know What You Did Last Summer (1998) H... \n", + "10 Addams Family, The (1991) Children Comedy Fantasy \n", + "11 Addams Family Reunion (1998) Children Comedy F... \n", + "12 Verdict, The (1982) Drama Mystery \n", + "13 Verdict, The (1946) Crime Drama Film-Noir Myst... \n", + "14 Diamonds (1999) Mystery \n", + "15 Pacific Heights (1990) Mystery Thriller \n", + "16 Puppet Master (1989) Horror Sci-Fi Thriller \n", + "17 Puppet Master II (1991) Horror Sci-Fi Thriller \n", + "18 Short Circuit 2 (1988) Comedy Sci-Fi \n", + "19 Short Circuit (1986) Comedy Sci-Fi \n", + "20 Alias Betty (Betty Fisher et autres histoires)... \n", + "21 Lila Says (Lila dit ça) (2004) Crime Drama Rom... \n", + "22 Faces of Death 3 (1985) Documentary Horror \n", + "23 Faces of Death 6 (1996) Documentary Horror \n", + "24 All the Real Girls (2003) Drama Romance \n", + "25 After Sex (2007) Drama Romance \n", + "26 Walk, Don't Run (1966) Comedy Romance \n", + "27 Some Came Running (1958) Drama Romance \n", + "28 Two Mules for Sister Sara (1970) Comedy War We... \n", + "29 Ballad of Cable Hogue, The (1970) Comedy Western \n", + "30 To End All Wars (2001) Action Drama War \n", + "31 Star Trek Into Darkness (2013) Action Adventur... \n", + "32 Spinning Boris (2003) Comedy Drama \n", + "33 Love and Other Disasters (2006) Comedy Romance \n", + "34 Tarnation (2003) Documentary \n", + "35 Unknown Known, The (2013) Documentary \n", + "36 Unknown Soldier, The (Tuntematon sotilas) (195... \n", + "37 Unknown Soldier, The (Tuntematon sotilas) (198... \n", + "38 Lone Wolf and Cub: Baby Cart at the River Styx... \n", + "39 Lone Wolf and Cub: Baby Cart in Peril (Kozure ... \n", + "40 London (2005) Drama \n", + "41 London (1994) Documentary \n", + "42 Iron Man (2008) Action Adventure Sci-Fi \n", + "43 Iron Man 2 (2010) Action Adventure Sci-Fi Thri... \n", + "44 Red Cliff (Chi bi) (2008) Action Adventure Dra... \n", + "45 Red Cliff Part II (Chi Bi Xia: Jue Zhan Tian X... \n", + "46 Autism: The Musical (2007) Documentary \n", + "47 Public Speaking (2010) Documentary \n", + "48 Underworld: Rise of the Lycans (2009) Action F... \n", + "49 Underworld: Awakening (2012) Action Fantasy Ho... \n", + "50 Red Riding: 1980 (2009) Crime Drama Mystery \n", + "51 Red Riding: 1983 (2009) Crime Drama Mystery \n", + "52 Atlas Shrugged: Part 1 (2011) Drama Mystery Sc... \n", + "53 Atlas Shrugged: Part II (2012) Drama Mystery S... \n", + "54 Drive (2011) Crime Drama Film-Noir Thriller \n", + "55 Dead Man Down (2013) Action Crime Drama Romanc... \n", + "56 Pieta (2013) Drama \n", + "57 Belle (2013) Drama " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from utils import find_conflicts_df\n", + "\n", + "find_conflicts_df(tokenizer.id2sid, meta)" + ] + }, + { + "cell_type": "markdown", + "id": "model-header", + "metadata": {}, + "source": [ + "## 4. Create a TIGERModel\n", + "\n", + "The `TIGERModel` wraps the TIGER Transformer encoder-decoder.\n", + "It accepts a pretrained tokenizer and model/training hyperparameters." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "create-model", + "metadata": {}, + "outputs": [], + "source": [ + "tiger = TIGERModel(\n", + " tokenizer=tokenizer,\n", + " hidden_units=128,\n", + " num_blocks=4,\n", + " num_heads=6,\n", + " dropout_rate=0.1,\n", + " max_length=20,\n", + " # Training hyperparams\n", + " lr=1e-3,\n", + " lr_schedule=\"cosine\",\n", + " max_epochs=3,\n", + " patience=None,\n", + " batch_size=256,\n", + " eval_batch_size=64,\n", + " beam_size=20,\n", + " top_k=10,\n", + " d_kv=64,\n", + " device=\"cuda\"\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "split-header", + "metadata": {}, + "source": [ + "## 5. Split data with leave-one-out\n", + "\n", + "`loo_split()` performs a leave-one-out split on the interactions DataFrame:\n", + "- **train**: all interactions except the last 2 per user\n", + "- **val**: all interactions except the last 1 per user\n", + "- **test**: all interactions\n", + "\n", + "Users with fewer than 3 interactions are dropped." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "split", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: 9,704,103 interactions, 136,674 users\n", + "Val: 9,840,777 interactions, 136,674 users\n", + "Test: 9,977,451 interactions, 136,674 users\n" + ] + } + ], + "source": [ + "train_df, val_df, test_df = loo_split(interactions)\n", + "\n", + "print(f\"Train: {len(train_df):,} interactions, {train_df['user_id'].nunique():,} users\")\n", + "print(f\"Val: {len(val_df):,} interactions, {val_df['user_id'].nunique():,} users\")\n", + "print(f\"Test: {len(test_df):,} interactions, {test_df['user_id'].nunique():,} users\")" + ] + }, + { + "cell_type": "markdown", + "id": "train-header", + "metadata": {}, + "source": [ + "## 6. Train the model\n", + "\n", + "`fit()` accepts pre-split train and validation DataFrames.\n", + "You control the split strategy -- the model does not split internally." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "train", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n", + "Trainer already configured with model summary callbacks: []. Skipping setting a default `ModelSummary` callback.\n", + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "You are using a CUDA device ('NVIDIA H100 80GB HBM3 MIG 2g.20gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n", + "\n", + " | Name | Type | Params | Mode \n", + "-------------------------------------------\n", + "0 | model | TIGERNet | 3.6 M | train\n", + "-------------------------------------------\n", + "3.6 M Trainable params\n", + "0 Non-trainable params\n", + "3.6 M Total params\n", + "14.455 Total estimated model params size (MB)\n", + "150 Modules in train mode\n", + "0 Modules in eval mode\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "82895486bd2043d3983219c3ac546537", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Sanity Checking: | | 0/? [00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric DataLoader 0 ┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│ Coverage@10 0.0023603341542184353 │\n", + "│ Gini@10 0.7614988684654236 │\n", + "│ Hit@10 0.0728229209780693 │\n", + "│ MRR@10 0.024664802476763725 │\n", + "│ NDCG@10 0.03576871380209923 │\n", + "└───────────────────────────┴───────────────────────────┘\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m Coverage@10 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0023603341542184353 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m Gini@10 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.7614988684654236 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m Hit@10 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.0728229209780693 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m MRR@10 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.024664802476763725 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m NDCG@10 \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 0.03576871380209923 \u001b[0m\u001b[35m \u001b[0m│\n", + "└───────────────────────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Hit@10: 0.0728\n", + "NDCG@10: 0.0358\n", + "MRR@10: 0.0247\n", + "Gini@10: 0.7615\n", + "Coverage@10: 0.0024\n" + ] + } + ], + "source": [ + "results = tiger.evaluate(test_df)\n", + "\n", + "for metric, value in results.items():\n", + " print(f\"{metric}: {value:.4f}\")" + ] + }, + { + "cell_type": "markdown", + "id": "predict-header", + "metadata": {}, + "source": [ + "## 8. Generate recommendations\n", + "\n", + "`predict()` takes a DataFrame of user interaction histories and returns\n", + "a DataFrame of recommendations with columns: `user_id`, `item_id`, `score`, `rank`.\n", + "\n", + "We can join with movie metadata to see the actual titles." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "predict", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
user_iditem_idscoreranktitle
015952-4.5950251Lord of the Rings: The Two Towers, The (2002)
113578-4.6733812Gladiator (2000)
2159315-4.7512423Iron Man (2008)
316539-4.8580474Pirates of the Caribbean: The Curse of the Bla...
415349-5.8122825Spider-Man (2002)
521196-3.8954081Star Wars: Episode V - The Empire Strikes Back...
62260-4.1246452Star Wars: Episode IV - A New Hope (1977)
721210-4.1798703Star Wars: Episode VI - Return of the Jedi (1983)
82541-4.4146594Blade Runner (1982)
921240-4.4788415Terminator, The (1984)
1031210-4.0719711Star Wars: Episode VI - Return of the Jedi (1983)
113780-4.4140122Independence Day (a.k.a. ID4) (1996)
123260-4.5106683Star Wars: Episode IV - A New Hope (1977)
1331196-4.5729674Star Wars: Episode V - The Empire Strikes Back...
1431240-4.6479295Terminator, The (1984)
154780-3.8602511Independence Day (a.k.a. ID4) (1996)
164648-4.6089922Mission: Impossible (1996)
174474-4.6941913In the Line of Fire (1993)
184736-4.7149364Twister (1996)
19495-4.9400845Broken Arrow (1996)
205260-4.8909901Star Wars: Episode IV - A New Hope (1977)
2151196-5.1302782Star Wars: Episode V - The Empire Strikes Back...
2251198-5.2161893Raiders of the Lost Ark (Indiana Jones and the...
235541-5.2270414Blade Runner (1982)
245457-5.2676865Fugitive, The (1993)
\n", + "
" + ], + "text/plain": [ + " user_id item_id score rank \\\n", + "0 1 5952 -4.595025 1 \n", + "1 1 3578 -4.673381 2 \n", + "2 1 59315 -4.751242 3 \n", + "3 1 6539 -4.858047 4 \n", + "4 1 5349 -5.812282 5 \n", + "5 2 1196 -3.895408 1 \n", + "6 2 260 -4.124645 2 \n", + "7 2 1210 -4.179870 3 \n", + "8 2 541 -4.414659 4 \n", + "9 2 1240 -4.478841 5 \n", + "10 3 1210 -4.071971 1 \n", + "11 3 780 -4.414012 2 \n", + "12 3 260 -4.510668 3 \n", + "13 3 1196 -4.572967 4 \n", + "14 3 1240 -4.647929 5 \n", + "15 4 780 -3.860251 1 \n", + "16 4 648 -4.608992 2 \n", + "17 4 474 -4.694191 3 \n", + "18 4 736 -4.714936 4 \n", + "19 4 95 -4.940084 5 \n", + "20 5 260 -4.890990 1 \n", + "21 5 1196 -5.130278 2 \n", + "22 5 1198 -5.216189 3 \n", + "23 5 541 -5.227041 4 \n", + "24 5 457 -5.267686 5 \n", + "\n", + " title \n", + "0 Lord of the Rings: The Two Towers, The (2002) \n", + "1 Gladiator (2000) \n", + "2 Iron Man (2008) \n", + "3 Pirates of the Caribbean: The Curse of the Bla... \n", + "4 Spider-Man (2002) \n", + "5 Star Wars: Episode V - The Empire Strikes Back... \n", + "6 Star Wars: Episode IV - A New Hope (1977) \n", + "7 Star Wars: Episode VI - Return of the Jedi (1983) \n", + "8 Blade Runner (1982) \n", + "9 Terminator, The (1984) \n", + "10 Star Wars: Episode VI - Return of the Jedi (1983) \n", + "11 Independence Day (a.k.a. ID4) (1996) \n", + "12 Star Wars: Episode IV - A New Hope (1977) \n", + "13 Star Wars: Episode V - The Empire Strikes Back... \n", + "14 Terminator, The (1984) \n", + "15 Independence Day (a.k.a. ID4) (1996) \n", + "16 Mission: Impossible (1996) \n", + "17 In the Line of Fire (1993) \n", + "18 Twister (1996) \n", + "19 Broken Arrow (1996) \n", + "20 Star Wars: Episode IV - A New Hope (1977) \n", + "21 Star Wars: Episode V - The Empire Strikes Back... \n", + "22 Raiders of the Lost Ark (Indiana Jones and the... \n", + "23 Blade Runner (1982) \n", + "24 Fugitive, The (1993) " + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sample_users = interactions[\"user_id\"].unique()[:5]\n", + "sample_interactions = interactions[interactions[\"user_id\"].isin(sample_users)]\n", + "\n", + "recommendations = tiger.predict(sample_interactions, top_k=5)\n", + "recommendations.merge(movies[[\"item_id\", \"title\"]], on=\"item_id\", how=\"left\")" + ] + }, + { + "cell_type": "markdown", + "id": "save-header", + "metadata": {}, + "source": [ + "## 9. Save and load the model\n", + "\n", + "`save()` writes the tokenizer, model weights, and config to a directory.\n", + "`load()` reconstructs the full `TIGERModel` from that directory." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "save-load", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saved to /tmp/tmp2mudyxqd/tiger_model\n", + "Contents: ['tokenizer.pt', 'model.pt', 'config.json']\n" + ] + }, + { + "ename": "UnpicklingError", + "evalue": "Weights only load failed. This file can still be loaded, to do so you have two options, \u001b[1mdo those steps only if you trust the source of the checkpoint\u001b[0m. \n\t(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.\n\t(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.\n\tWeightsUnpickler error: Unsupported global: GLOBAL rectools.semantic.tokenizer.rqvae.RQVAE was not an allowed global by default. Please use `torch.serialization.add_safe_globals([rectools.semantic.tokenizer.rqvae.RQVAE])` or the `torch.serialization.safe_globals([rectools.semantic.tokenizer.rqvae.RQVAE])` context manager to allowlist this global if you trust this class/function.\n\nCheck the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mUnpicklingError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[14], line 10\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSaved to \u001b[39m\u001b[38;5;132;01m{\u001b[39;00msave_path\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 8\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mContents: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mos\u001b[38;5;241m.\u001b[39mlistdir(save_path)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 10\u001b[0m loaded_tiger \u001b[38;5;241m=\u001b[39m \u001b[43mTIGERModel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43msave_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[38;5;124mLoaded model with \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(loaded_tiger\u001b[38;5;241m.\u001b[39mtokenizer)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m item SIDs\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 12\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCodebook sizes: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mloaded_tiger\u001b[38;5;241m.\u001b[39mcodebook_sizes\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/workspace/rectools/rectools/semantic/tiger/model.py:353\u001b[0m, in \u001b[0;36mTIGERModel.load\u001b[0;34m(cls, directory, device)\u001b[0m\n\u001b[1;32m 350\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mopen\u001b[39m(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(directory, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mconfig.json\u001b[39m\u001b[38;5;124m\"\u001b[39m), encoding\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mutf-8\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m f:\n\u001b[1;32m 351\u001b[0m config \u001b[38;5;241m=\u001b[39m json\u001b[38;5;241m.\u001b[39mload(f)\n\u001b[0;32m--> 353\u001b[0m tokenizer \u001b[38;5;241m=\u001b[39m \u001b[43mSIDTokenizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mjoin\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdirectory\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtokenizer.pt\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 355\u001b[0m tiger_model \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m(tokenizer\u001b[38;5;241m=\u001b[39mtokenizer, device\u001b[38;5;241m=\u001b[39mdevice, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mconfig)\n\u001b[1;32m 357\u001b[0m state_dict \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mload(os\u001b[38;5;241m.\u001b[39mpath\u001b[38;5;241m.\u001b[39mjoin(directory, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmodel.pt\u001b[39m\u001b[38;5;124m\"\u001b[39m), map_location\u001b[38;5;241m=\u001b[39mdevice)\n", + "File \u001b[0;32m~/workspace/rectools/rectools/semantic/tokenizer/model.py:165\u001b[0m, in \u001b[0;36mSIDTokenizer.load\u001b[0;34m(cls, path, device)\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[38;5;129m@classmethod\u001b[39m\n\u001b[1;32m 146\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mload\u001b[39m(\n\u001b[1;32m 147\u001b[0m \u001b[38;5;28mcls\u001b[39m,\n\u001b[1;32m 148\u001b[0m path: Union[\u001b[38;5;28mstr\u001b[39m, PathLike[\u001b[38;5;28mstr\u001b[39m]],\n\u001b[1;32m 149\u001b[0m device: Optional[\u001b[38;5;28mstr\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 150\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mSIDTokenizer\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 151\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Load serialized SIDTokenizer.\u001b[39;00m\n\u001b[1;32m 152\u001b[0m \n\u001b[1;32m 153\u001b[0m \u001b[38;5;124;03m Parameters\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[38;5;124;03m A de-serialized tokenizer.\u001b[39;00m\n\u001b[1;32m 164\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 165\u001b[0m loaded_data \u001b[38;5;241m=\u001b[39m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mload\u001b[49m\u001b[43m(\u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmap_location\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 167\u001b[0m tok \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mcls\u001b[39m(\n\u001b[1;32m 168\u001b[0m quantizer\u001b[38;5;241m=\u001b[39mloaded_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquantizer\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 169\u001b[0m id2sid\u001b[38;5;241m=\u001b[39mloaded_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mid2sid\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m 170\u001b[0m device\u001b[38;5;241m=\u001b[39mdevice,\n\u001b[1;32m 171\u001b[0m )\n\u001b[1;32m 173\u001b[0m \u001b[38;5;66;03m# Remind the codebooks to not re-init themselves\u001b[39;00m\n", + "File \u001b[0;32m~/.clearml/venvs-builds/3.12/lib/python3.12/site-packages/torch/serialization.py:1548\u001b[0m, in \u001b[0;36mload\u001b[0;34m(f, map_location, pickle_module, weights_only, mmap, **pickle_load_args)\u001b[0m\n\u001b[1;32m 1540\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _load(\n\u001b[1;32m 1541\u001b[0m opened_zipfile,\n\u001b[1;32m 1542\u001b[0m map_location,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1545\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpickle_load_args,\n\u001b[1;32m 1546\u001b[0m )\n\u001b[1;32m 1547\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m pickle\u001b[38;5;241m.\u001b[39mUnpicklingError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[0;32m-> 1548\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m pickle\u001b[38;5;241m.\u001b[39mUnpicklingError(_get_wo_message(\u001b[38;5;28mstr\u001b[39m(e))) \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m 1549\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m _load(\n\u001b[1;32m 1550\u001b[0m opened_zipfile,\n\u001b[1;32m 1551\u001b[0m map_location,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1554\u001b[0m \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpickle_load_args,\n\u001b[1;32m 1555\u001b[0m )\n\u001b[1;32m 1556\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m mmap:\n", + "\u001b[0;31mUnpicklingError\u001b[0m: Weights only load failed. This file can still be loaded, to do so you have two options, \u001b[1mdo those steps only if you trust the source of the checkpoint\u001b[0m. \n\t(1) In PyTorch 2.6, we changed the default value of the `weights_only` argument in `torch.load` from `False` to `True`. Re-running `torch.load` with `weights_only` set to `False` will likely succeed, but it can result in arbitrary code execution. Do it only if you got the file from a trusted source.\n\t(2) Alternatively, to load with `weights_only=True` please check the recommended steps in the following error message.\n\tWeightsUnpickler error: Unsupported global: GLOBAL rectools.semantic.tokenizer.rqvae.RQVAE was not an allowed global by default. Please use `torch.serialization.add_safe_globals([rectools.semantic.tokenizer.rqvae.RQVAE])` or the `torch.serialization.safe_globals([rectools.semantic.tokenizer.rqvae.RQVAE])` context manager to allowlist this global if you trust this class/function.\n\nCheck the documentation of torch.load to learn more about types accepted by default with weights_only https://pytorch.org/docs/stable/generated/torch.load.html." + ] + } + ], + "source": [ + "import tempfile\n", + "import os\n", + "\n", + "with tempfile.TemporaryDirectory() as tmpdir:\n", + " save_path = os.path.join(tmpdir, \"tiger_model\")\n", + " tiger.save(save_path)\n", + " print(f\"Saved to {save_path}\")\n", + " print(f\"Contents: {os.listdir(save_path)}\")\n", + "\n", + " loaded_tiger = TIGERModel.load(save_path)\n", + " print(f\"\\nLoaded model with {len(loaded_tiger.tokenizer)} item SIDs\")\n", + " print(f\"Codebook sizes: {loaded_tiger.codebook_sizes}\")" + ] + }, + { + "cell_type": "markdown", + "id": "summary", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "The TIGER pipeline in RecTools:\n", + "\n", + "| Step | Function / Method | Input | Output |\n", + "|------|-------------------|-------|--------|\n", + "| Filtering | `k_core(df)` | Interactions DataFrame | Filtered DataFrame |\n", + "| Embedding | `encode_texts(texts, model)` | List of text strings | numpy embedding matrix |\n", + "| Tokenization | `SIDTokenizer(...).fit(ids, embs)` | Item IDs + embeddings | `SIDTokenizer` |\n", + "| Splitting | `loo_split(df)` | Interactions DataFrame | train, val, test DataFrames |\n", + "| Model creation | `TIGERModel(tokenizer)` | `SIDTokenizer` + hyperparams | `TIGERModel` |\n", + "| Training | `model.fit(train, val)` | Train + val DataFrames | Trained model (in-place) |\n", + "| Evaluation | `model.evaluate(test)` | Test DataFrame | Metrics dict |\n", + "| Prediction | `model.predict(df)` | Interactions DataFrame | Recommendations DataFrame |\n", + "| Persistence | `model.save(dir)` / `TIGERModel.load(dir)` | Directory path | Saved/loaded model |" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "3.12", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/tutorials/utils.py b/examples/tutorials/utils.py index 8fdf8a8a..b8ea7d54 100644 --- a/examples/tutorials/utils.py +++ b/examples/tutorials/utils.py @@ -16,6 +16,7 @@ import os import typing as tp import warnings +from collections import defaultdict from pathlib import Path import matplotlib.pyplot as plt @@ -329,3 +330,61 @@ def get_results(path_to_load_res: str, metrics_to_show: tp.List[str], show_loss: ) pivot_results.columns = pivot_results.columns.droplevel(1) return pivot_results[metrics_to_show] + + +def find_conflicts(id2sid: dict) -> dict[tuple[int, ...], list[int]]: + """Return a mapping from SID -> list of item IDs for every SID shared by 2+ items. + + Parameters + ---------- + id2sid : dict + Mapping from item ID to SID tuple, e.g. from ``SIDTokenizer.id2sid``. + + Returns + ------- + dict[tuple[int, ...], list[int]] + Conflicting SIDs mapped to the item IDs that share them. + """ + sid2ids: dict[tuple[int, ...], list[int]] = defaultdict(list) + for item_id, sid in id2sid.items(): + sid2ids[sid].append(item_id) + return {sid: ids for sid, ids in sid2ids.items() if len(ids) > 1} + + +def find_conflicts_df( + id2sid: dict, + metadata: pd.DataFrame, + item_col: str = "item_id", + text_col: str = "text", +) -> pd.DataFrame: + """Find conflicting SIDs and return a DataFrame enriched with item metadata. + + Parameters + ---------- + id2sid : dict + Mapping from item ID to SID tuple, e.g. from ``SIDTokenizer.id2sid``. + metadata : pd.DataFrame + Item metadata with at least ``item_col`` and ``text_col`` columns. + item_col : str + Name of the item ID column in ``metadata``. + text_col : str + Name of the text column in ``metadata``. + + Returns + ------- + pd.DataFrame + Columns: ``SID``, ``item_id``, ``text``. One row per conflicting item, + sorted by cluster size descending. Empty if no conflicts. + """ + conflicts = find_conflicts(id2sid) + if not conflicts: + return pd.DataFrame(columns=["SID", "item_id", "text"]) + + id2text = dict(zip(metadata[item_col], metadata[text_col])) + + rows = [] + for sid, ids in sorted(conflicts.items(), key=lambda kv: -len(kv[1])): + for item_id in ids: + rows.append({"SID": sid, "item_id": item_id, "text": id2text.get(item_id, "")}) + + return pd.DataFrame(rows, columns=["SID", "item_id", "text"]) diff --git a/poetry.lock b/poetry.lock index b1b15e5b..9c0bcb2b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,15 @@ -# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand. + +[[package]] +name = "absl-py" +version = "2.3.1" +description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py." +optional = true +python-versions = ">=3.8" +files = [ + {file = "absl_py-2.3.1-py3-none-any.whl", hash = "sha256:eeecf07f0c2a93ace0772c92e596ace6d3d3996c042b2128459aaae2a76de11d"}, + {file = "absl_py-2.3.1.tar.gz", hash = "sha256:a97820526f7fbfd2ec1bce83f3f25e3a14840dac0d8e02a0b71cd75db3f77fc9"}, +] [[package]] name = "aiohappyeyeballs" @@ -1631,6 +1642,17 @@ files = [ {file = "imagesize-1.4.1.tar.gz", hash = "sha256:69150444affb9cb0d5cc5a92b3676f0b2fb7cd9ae39e947a5e11a36b4497cd4a"}, ] +[[package]] +name = "immutabledict" +version = "4.3.1" +description = "Immutable wrapper around dictionaries (a fork of frozendict)" +optional = true +python-versions = "<4.0,>=3.8" +files = [ + {file = "immutabledict-4.3.1-py3-none-any.whl", hash = "sha256:c9facdc0ff30fdb8e35bd16532026cac472a549e182c94fa201b51b25e4bf7bf"}, + {file = "immutabledict-4.3.1.tar.gz", hash = "sha256:f844a669106cfdc73f47b1a9da003782fb17dc955a54c80972e0d93d1c63c514"}, +] + [[package]] name = "implicit" version = "0.7.2" @@ -1987,6 +2009,83 @@ files = [ {file = "jupyterlab_widgets-3.0.16.tar.gz", hash = "sha256:423da05071d55cf27a9e602216d35a3a65a3e41cdf9c5d3b643b814ce38c19e0"}, ] +[[package]] +name = "k-means-constrained" +version = "0.7.3" +description = "K-Means clustering constrained with minimum and maximum cluster size" +optional = true +python-versions = ">=3.8" +files = [ + {file = "k-means-constrained-0.7.3.tar.gz", hash = "sha256:2b5bf8317614d749cf3da6b6f4adf52dee34cdf0dc16b8a26b40d72f1a000dff"}, + {file = "k_means_constrained-0.7.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8d8304b1688d16d31b2803a94504484524ee8977fa2eaac53e9a33a4f1ef0c9e"}, + {file = "k_means_constrained-0.7.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5b1d79e9032c0471f522a964fa1804f776a980fc932c2793d0d256d33b762119"}, + {file = "k_means_constrained-0.7.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68e1736e9b8a6dfd2af0de5200ef03fc5b13ef55907e3f5431289ee519d704ae"}, + {file = "k_means_constrained-0.7.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:15d838a941043fe6480ff36f7d8e9128fb73f193248c42b6fdbaf2aa5b8ce00e"}, + {file = "k_means_constrained-0.7.3-cp310-cp310-win_amd64.whl", hash = "sha256:521881df8362dce22875e431150f74a4f6f17cf798bbe2f3711d145f97197335"}, + {file = "k_means_constrained-0.7.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:04e07d7fecc36af499f13efd0925712a7dc8382fb10f5582d780f3c7a2304c71"}, + {file = "k_means_constrained-0.7.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:2f5034fa54879ff7e15882e3d4cfe572e24549447717a6fcf724bbe1520e06fe"}, + {file = "k_means_constrained-0.7.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:08288fe3011783a1ffd0c18cf314ea604e9d1ba68d601ab17584428097c15ad4"}, + {file = "k_means_constrained-0.7.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b279f4807314bbeabf5bd5fb78de382607d8e97b31bb277d9a09d9a8b412b969"}, + {file = "k_means_constrained-0.7.3-cp311-cp311-win_amd64.whl", hash = "sha256:d6419494fdfbad3448f2c6ab8fb9492cadaca723a3b3d1a7acbacef894095a17"}, + {file = "k_means_constrained-0.7.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:f229dc1791a9d792b68b3c79555dc7d94462b56fd2ffc24914cb3c115a12dbb6"}, + {file = "k_means_constrained-0.7.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d116ad48100c6ee3f57ee4c675792a4a47569dbe817c25899aa29bcdedb184aa"}, + {file = "k_means_constrained-0.7.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:05b6b84ba06a79c3d50e787770a7ba02f5ba50df3f2a3dc1a106e5155de656b8"}, + {file = "k_means_constrained-0.7.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da68417759907b0abf127b876f2b92264a0d6fc056165d8080c1557dd79e339a"}, + {file = "k_means_constrained-0.7.3-cp38-cp38-win_amd64.whl", hash = "sha256:98beab00776bd887f4fe28322b1183f85da3d4e2e2d8f2960d8d9edf21412a3b"}, + {file = "k_means_constrained-0.7.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e4fb34b8ac2c9857505f8d34b8f77e2dd21e2cba28e811f4bb9a7e37844ba6c7"}, + {file = "k_means_constrained-0.7.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3965bc53eddb00777fc0fd34bc4332fde29c283c93ac82326a3678b175e180e5"}, + {file = "k_means_constrained-0.7.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1354253318ed7c688ac6f53d6e29af8065e5a64965bb535279cd6e40e66c63af"}, + {file = "k_means_constrained-0.7.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26c07b374ed4dda4e396442e4661aed6442294858c8cf38715a0b44f816da590"}, + {file = "k_means_constrained-0.7.3-cp39-cp39-win_amd64.whl", hash = "sha256:a4ffd8222d09f392c2c8822f489093c509f2c0d10e10a0d6f0e42a0abff4061d"}, +] + +[package.dependencies] +joblib = "*" +numpy = ">=1.23.0" +ortools = ">=9.4.1874" +scipy = ">=1.6.3" +six = "*" + +[package.extras] +dev = ["bump2version", "nose", "numpydoc", "pandas (>=1.0.4)", "pytest (>=5.1)", "scikit-learn (>=0.24.2)", "sphinx", "sphinx-rtd-theme", "twine"] +docs = ["sphinx", "sphinx-rtd-theme"] + +[[package]] +name = "k-means-constrained" +version = "0.7.6" +description = "K-Means clustering constrained with minimum and maximum cluster size" +optional = true +python-versions = "*" +files = [ + {file = "k_means_constrained-0.7.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:fe6dc2acfeafda2e7bdb0c3d290dbb0b4b57078ad47990a1ba783a26ed374286"}, + {file = "k_means_constrained-0.7.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6fe7b59e45b8ef21720e16a4b75c69e242c5b447bc4ce230c0c33b9caf91f87b"}, + {file = "k_means_constrained-0.7.6-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3629dfe2fa0629cc1ab8864f84e1b19651ab272e4016af0b95791215d0573692"}, + {file = "k_means_constrained-0.7.6-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5fb57b1e6c4c2bdc2f80b9c993f97345bfa8987269da4ca50e860acd93f18126"}, + {file = "k_means_constrained-0.7.6-cp310-cp310-win_amd64.whl", hash = "sha256:ad191818fe2c2a216ac1bc8a69790d3654cffaa88797e2f48c67f3f5c5315b39"}, + {file = "k_means_constrained-0.7.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c841ea6629786697bb486ceeac0bec59d016794b468110ca6d521446ba12c995"}, + {file = "k_means_constrained-0.7.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ef667d9191321cef7ef4ef2132c82c017b44e8fb334a91e975167024685b89be"}, + {file = "k_means_constrained-0.7.6-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:eaaa77929eb6728beb16a2bd755104eee8068d807a1cbfce9a8b206f0cb3b2e2"}, + {file = "k_means_constrained-0.7.6-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:28425e8f70d6035d482ff00a661c69c7a3e400403583916d596306daf3e79199"}, + {file = "k_means_constrained-0.7.6-cp311-cp311-win_amd64.whl", hash = "sha256:a5783224f20ef4703dba4b21a2d0dc22e3b314a6ed5c064f6767021bd932a93b"}, + {file = "k_means_constrained-0.7.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:05734675db6d5d43888394690640d283a2c1d5529435ecba87e172dffc806189"}, + {file = "k_means_constrained-0.7.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:34e08d5b707905df635ba1145e639269745c56c347157717265c3e48c2a04ba6"}, + {file = "k_means_constrained-0.7.6-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9eb967fbf28b6bd91fdfed3bc543fae4973d8a474a3aa6dabd141236ebb36682"}, + {file = "k_means_constrained-0.7.6-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:736d97ff642dfaaf0df4b63632b827ba4b1466794e111970fbfc99e2c52286a2"}, + {file = "k_means_constrained-0.7.6-cp312-cp312-win_amd64.whl", hash = "sha256:599eb9e22e2c99b50dd7d9bf434789de7c0831972c508789be40a25b1796469b"}, + {file = "k_means_constrained-0.7.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4cdf3de34fb7f24a78af05e3eaec063747258b8f71513faf85b72e9f601bc351"}, + {file = "k_means_constrained-0.7.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:24ae15921827e655cfc13e6c9138e7f6b9be3bc488b2030eb19461de66aff42e"}, + {file = "k_means_constrained-0.7.6-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2e2d9842493bf934eafadb68b689b715e1b2880ba58af05263ee2e11ccb3c6d7"}, + {file = "k_means_constrained-0.7.6-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:876391bb2f1a79ece65d1dba8a76545da09b9ee4cb287f467e861c1043caaf4f"}, + {file = "k_means_constrained-0.7.6-cp313-cp313-win_amd64.whl", hash = "sha256:a828bbc81dde339e7ae97934bf9dcb43c5b49884d1d9c8fb0ffa117f5c998233"}, +] + +[package.dependencies] +joblib = "*" +numpy = ">=2.1.1" +ortools = ">=9.11.4210" +scipy = ">=1.14.1" +six = "*" + [[package]] name = "kiwisolver" version = "1.4.7" @@ -3332,6 +3431,113 @@ files = [ {file = "nvidia_nvtx_cu12-12.8.90-py3-none-win_amd64.whl", hash = "sha256:619c8304aedc69f02ea82dd244541a83c3d9d40993381b3b590f1adaed3db41e"}, ] +[[package]] +name = "ortools" +version = "9.14.6206" +description = "Google OR-Tools python libraries and modules" +optional = true +python-versions = ">=3.9" +files = [ + {file = "ortools-9.14.6206-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:6e2364edd1577cd094e7c7121ec5fb0aa462a69a78ce29cdc40fa45943ff0091"}, + {file = "ortools-9.14.6206-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:164b726b4d358ae68a018a52ff1999c0646d6f861b33676c2c83e2ddb60cfa13"}, + {file = "ortools-9.14.6206-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ebb0e210969cc3246fe78dadf9038936a3a18edc8156e23a394e2bbcec962431"}, + {file = "ortools-9.14.6206-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:174de2f04c106c7dcc5989560f2c0e065e78fba0ad0d1fd029897582f4823c3a"}, + {file = "ortools-9.14.6206-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:e6d994ebcf9cbdda1e20a75662967124e7e6ffd707c7f60b2db1a11f2104d384"}, + {file = "ortools-9.14.6206-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:5763472f8b05072c96c36c4eafadd9f6ffcdab38a81d8f0142fc408ad52a4342"}, + {file = "ortools-9.14.6206-cp310-cp310-win_amd64.whl", hash = "sha256:6711516f837f06836ff9fda66fe4337b88c214f2ba6a921b84d3b05876f1fa8c"}, + {file = "ortools-9.14.6206-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:8bcd8481846090585a4fac82800683555841685c49fa24578ad1e48a37918568"}, + {file = "ortools-9.14.6206-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5af2bbf2fff7d922ba036e27d7ff378abecb24749380c86a77fa6208d5ba35cd"}, + {file = "ortools-9.14.6206-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a6ab43490583c4bbf0fff4e51bb1c15675d5651c2e8e12ba974fd08e8c05a48f"}, + {file = "ortools-9.14.6206-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9aa2c0c50a765c6a060960dcb0207bd6aeb6341f5adacb3d33e613b7e7409428"}, + {file = "ortools-9.14.6206-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:64ec63fd92125499e9ca6b72700406dda161eefdfef92f04c35c5150391f89a4"}, + {file = "ortools-9.14.6206-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:8651008f05257471f45a919ade5027afa12ab6f7a4fdf0a8bcc18c92032f8571"}, + {file = "ortools-9.14.6206-cp311-cp311-win_amd64.whl", hash = "sha256:ca60877830a631545234e83e7f6bd55830334a4d0c2b51f1669b1f2698d58b84"}, + {file = "ortools-9.14.6206-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:e38c8c4a184820cbfdb812a8d484f6506cf16993ce2a95c88bc1c9d23b17c63e"}, + {file = "ortools-9.14.6206-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:db685073cbed9f8bfaa744f5e883f3dea57c93179b0abe1788276fd3b074fa61"}, + {file = "ortools-9.14.6206-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d4bfb8bffb29991834cf4bde7048ca8ee8caed73e8dd21e5ec7de99a33bbfea0"}, + {file = "ortools-9.14.6206-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eb464a698837e7f90ca5f9b3d748b6ddf553198a70032bc77824d1cd88695d2b"}, + {file = "ortools-9.14.6206-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8f33deaeb7c3dda8ca1d29c5b9aa9c3a4f2ca9ecf34f12a1f809bb2995f41274"}, + {file = "ortools-9.14.6206-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:086e7c2dc4f23efffb20a5e20f618c7d6adb99b2d94f684cab482387da3bc434"}, + {file = "ortools-9.14.6206-cp312-cp312-win_amd64.whl", hash = "sha256:17c13b0bfde17ac57789ad35243edf1318ecd5db23cf949b75ab62480599f188"}, + {file = "ortools-9.14.6206-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:8d0df7eef8ba53ad235e29018389259bad2e667d9594b9c2a412ed6a5756bd4e"}, + {file = "ortools-9.14.6206-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:57dfe10844ce8331634d4723040fe249263fd490407346efc314c0bc656849b5"}, + {file = "ortools-9.14.6206-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5c0c2c00a6e5d5c462e76fdda7dbd40d0f9139f1df4211d34b36906696248020"}, + {file = "ortools-9.14.6206-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:38044cf39952d93cbcc02f6acdbe0a9bd3628fbf17f0d7eb0374060fa028c22e"}, + {file = "ortools-9.14.6206-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:98564de773d709e1e49cb3c32f6917589c314f047786d88bd5f324c0eb7be96e"}, + {file = "ortools-9.14.6206-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:80528b0ac72dc3de00cbeef2ce028517a476450b5877b1cda1b8ecb9fa98505e"}, + {file = "ortools-9.14.6206-cp313-cp313-win_amd64.whl", hash = "sha256:47b1b15dcb085d32c61621b790259193aefa9e4577abadf233d47fbe7d0b81ef"}, + {file = "ortools-9.14.6206-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d26a0f9ed97ef9d3384a9069923585f5f974c3fde555a41f4d6381fbe7840bc4"}, + {file = "ortools-9.14.6206-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d40d8141667d47405f296a9f687058c566d7816586e9a672b59e9fcec8493133"}, + {file = "ortools-9.14.6206-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:aefea81ed81aa937873efc520381785ed65380e52917f492ab566f46bbb5660d"}, + {file = "ortools-9.14.6206-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:f044bb277db3ab6a1b958728fe1cf14ca87c3800d67d7b321d876b48269340f6"}, + {file = "ortools-9.14.6206-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:070dc7cebfa0df066acb6b9a6d02339351be8f91b2352b782ee7f40412207e20"}, + {file = "ortools-9.14.6206-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5eb558a03b4ada501ecdea7b89f0d3bdf2cc6752e1728759ccf27923f592a8c2"}, + {file = "ortools-9.14.6206-cp39-cp39-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:646329fa74a5c48c591b7fabfd26743f6d2de4e632b3b96ec596c47bfe19177a"}, + {file = "ortools-9.14.6206-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:aa5161924f35b8244295acd0fab2a8171bb08ef8d5cfaf1913a21274475704cc"}, + {file = "ortools-9.14.6206-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:e253526a026ae194aed544a0d065163f52a0c9cb606a1061c62df546877d5452"}, + {file = "ortools-9.14.6206-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:dcb496ef633d884036770783f43bf8a47ff253ecdd8a8f5b95f00276ec241bfd"}, + {file = "ortools-9.14.6206-cp39-cp39-win_amd64.whl", hash = "sha256:2733f635675de631fdc7b1611878ec9ee2f48a26434b7b3c07d0a0f535b92e03"}, +] + +[package.dependencies] +absl-py = ">=2.0.0" +immutabledict = ">=3.0.0" +numpy = ">=1.13.3" +pandas = ">=2.0.0" +protobuf = ">=6.31.1,<6.32" +typing-extensions = ">=4.12" + +[[package]] +name = "ortools" +version = "9.15.6755" +description = "Google OR-Tools python libraries and modules" +optional = true +python-versions = ">=3.9" +files = [ + {file = "ortools-9.15.6755-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:e4559603031ed371c5d86b1e9357fa49fb89236452e4b9bc429a0cf4a2fab05d"}, + {file = "ortools-9.15.6755-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5bb2b434f4ae01ce81813d01db722d9dedcc452aede681211ee4d4df8963a410"}, + {file = "ortools-9.15.6755-cp310-cp310-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0b26655d25ab28030aef30e675e24d96d35940974de3a70ace01cf82ca301b69"}, + {file = "ortools-9.15.6755-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:03424136aa48555e7f4d1bc73edeb99f80ec35a2e5f17700e2072640344980fc"}, + {file = "ortools-9.15.6755-cp310-cp310-win_amd64.whl", hash = "sha256:4f4964f8ed47ac76b5cfd23238618299f5a3c289d8e0ed66a75885ba9766eb6f"}, + {file = "ortools-9.15.6755-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:55e291560d2fdb9590656cbee06ba99ee7f2476bd7d316ff757eeab33e9b20d6"}, + {file = "ortools-9.15.6755-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e51ae55569650e5381fd6e50c655ccf6368a9532f5720ea41396bb90e0247a21"}, + {file = "ortools-9.15.6755-cp311-cp311-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c3bcccd15ef3fc6ac10bfa11630ba6dfe437d4fd1374a5b33f4773b7fee0f877"}, + {file = "ortools-9.15.6755-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7a85a68ffb3fc1967e78624f40d3707aae459e82e3f2d9fe02e91788b3c7bf2a"}, + {file = "ortools-9.15.6755-cp311-cp311-win_amd64.whl", hash = "sha256:781fb09d6c9f46015291f706bd7c7e0815db1bec6e92c74716342fb7ea2d0532"}, + {file = "ortools-9.15.6755-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:ae1c6e1fd844b4d756b22eb6c0ed574ea4342ee206d807c4f903039e748228fa"}, + {file = "ortools-9.15.6755-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e16686c2b457fa6242c474ab890ee1712347ab53678e0d2fab307ae03e97a4b"}, + {file = "ortools-9.15.6755-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3cd6bec0a2e00e3891a53e3b436f45a1000269f302085572f49e9856b7f8eaf0"}, + {file = "ortools-9.15.6755-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:033836c0eb33bc72697a299e0caedbb25fc9d1cee0b13832d69cb30405f57b3e"}, + {file = "ortools-9.15.6755-cp312-cp312-win_amd64.whl", hash = "sha256:487796301fd9dad55f9cf21f9313c834697f74306d1a59f002e152862f8eb1b5"}, + {file = "ortools-9.15.6755-cp313-cp313-macosx_10_15_x86_64.whl", hash = "sha256:27a10474e62c9dceed37cfa0e4845c5ffaf792138ebf5b61483771b96f1290b6"}, + {file = "ortools-9.15.6755-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:076565b803c85c4f87863e0616f537dd37f99c03e6f092e4068404f7b425d2b0"}, + {file = "ortools-9.15.6755-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b85bd20259b146abce5e0721ce1bfd8fd273efc904216aa3be178c31b6d34057"}, + {file = "ortools-9.15.6755-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ebd5aea00374e3aad7a78de59058aca5e871a26a3c385cd0860ef1d685d03c9a"}, + {file = "ortools-9.15.6755-cp313-cp313-win_amd64.whl", hash = "sha256:caac1d48b967adb877da2abcaf82c28f0f908a7cc208a6a1bbe01bc69590816c"}, + {file = "ortools-9.15.6755-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:82b4a8e6e4f9380b453ab5fa4382ea7ee91e628f9b8be89d9ad760b33fca3323"}, + {file = "ortools-9.15.6755-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2d1f2fb2088e8953ccb902e68ffd06032cce0c7dcf7268b6135f3b6c553ca52b"}, + {file = "ortools-9.15.6755-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:acdf06a167933307608e7eba23a9490255933504df44c8de5f62c48656c29688"}, + {file = "ortools-9.15.6755-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1a0677270b0cd317a6b8dae42514264eaf5da5756c5bc7215eeea409424577df"}, + {file = "ortools-9.15.6755-cp314-cp314-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:899b92afe3f775ab5867b9a8aa2850f81f2d95232db9b4ceec3456d69e6b8528"}, + {file = "ortools-9.15.6755-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7181183cdcafe2b0d83ca5505b65048c7953dc7b5ad479361dded607964cc1b3"}, + {file = "ortools-9.15.6755-cp314-cp314-win_amd64.whl", hash = "sha256:afabb869e5fabeb704bd8147b22bf8139dee042e55fabd0d447a996428009e0c"}, + {file = "ortools-9.15.6755-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d9d07cddca201e25e2e219006a9d6cda10c7e9ee2c712c50d19d508f9ed8a888"}, + {file = "ortools-9.15.6755-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:990838ad66a052e72a50e69da500878710e3420e91717fe88bf3071995caba9e"}, + {file = "ortools-9.15.6755-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:73b229dbc2b225441cb3bc5790ea8d55f14e3cd7f32d5185784f60b102308457"}, + {file = "ortools-9.15.6755-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8881e9620bf0bf8303891e171feb03d6e86c75a05fb9325a09ae7fbf93093f4a"}, + {file = "ortools-9.15.6755-cp39-cp39-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:87c73acda29f03ded74c7d2388f6efcbe45fa45a3f2bae4d85e1b5f1cc4cd9c1"}, + {file = "ortools-9.15.6755-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:76150c4dd5927d0ab138344a5121b1f63e32501371ca93e632e2e10d8261064b"}, + {file = "ortools-9.15.6755-cp39-cp39-win_amd64.whl", hash = "sha256:d72c136fd6e4b112bf154680290490da0aca60b7ddfc4163581c069c67016d4a"}, +] + +[package.dependencies] +absl-py = ">=2.0.0" +immutabledict = ">=3.0.0" +numpy = ">=2.0.2" +pandas = ">=2.0.0" +protobuf = ">=6.33.1,<6.34" +typing-extensions = ">=4.12" + [[package]] name = "packaging" version = "26.0" @@ -3874,6 +4080,43 @@ files = [ {file = "propcache-0.4.1.tar.gz", hash = "sha256:f48107a8c637e80362555f37ecf49abe20370e557cc4ab374f04ec4423c97c3d"}, ] +[[package]] +name = "protobuf" +version = "6.31.1" +description = "" +optional = true +python-versions = ">=3.9" +files = [ + {file = "protobuf-6.31.1-cp310-abi3-win32.whl", hash = "sha256:7fa17d5a29c2e04b7d90e5e32388b8bfd0e7107cd8e616feef7ed3fa6bdab5c9"}, + {file = "protobuf-6.31.1-cp310-abi3-win_amd64.whl", hash = "sha256:426f59d2964864a1a366254fa703b8632dcec0790d8862d30034d8245e1cd447"}, + {file = "protobuf-6.31.1-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:6f1227473dc43d44ed644425268eb7c2e488ae245d51c6866d19fe158e207402"}, + {file = "protobuf-6.31.1-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:a40fc12b84c154884d7d4c4ebd675d5b3b5283e155f324049ae396b95ddebc39"}, + {file = "protobuf-6.31.1-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:4ee898bf66f7a8b0bd21bce523814e6fbd8c6add948045ce958b73af7e8878c6"}, + {file = "protobuf-6.31.1-cp39-cp39-win32.whl", hash = "sha256:0414e3aa5a5f3ff423828e1e6a6e907d6c65c1d5b7e6e975793d5590bdeecc16"}, + {file = "protobuf-6.31.1-cp39-cp39-win_amd64.whl", hash = "sha256:8764cf4587791e7564051b35524b72844f845ad0bb011704c3736cce762d8fe9"}, + {file = "protobuf-6.31.1-py3-none-any.whl", hash = "sha256:720a6c7e6b77288b85063569baae8536671b39f15cc22037ec7045658d80489e"}, + {file = "protobuf-6.31.1.tar.gz", hash = "sha256:d8cac4c982f0b957a4dc73a80e2ea24fab08e679c0de9deb835f4a12d69aca9a"}, +] + +[[package]] +name = "protobuf" +version = "6.33.6" +description = "" +optional = true +python-versions = ">=3.9" +files = [ + {file = "protobuf-6.33.6-cp310-abi3-win32.whl", hash = "sha256:7d29d9b65f8afef196f8334e80d6bc1d5d4adedb449971fefd3723824e6e77d3"}, + {file = "protobuf-6.33.6-cp310-abi3-win_amd64.whl", hash = "sha256:0cd27b587afca21b7cfa59a74dcbd48a50f0a6400cfb59391340ad729d91d326"}, + {file = "protobuf-6.33.6-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:9720e6961b251bde64edfdab7d500725a2af5280f3f4c87e57c0208376aa8c3a"}, + {file = "protobuf-6.33.6-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:e2afbae9b8e1825e3529f88d514754e094278bb95eadc0e199751cdd9a2e82a2"}, + {file = "protobuf-6.33.6-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:c96c37eec15086b79762ed265d59ab204dabc53056e3443e702d2681f4b39ce3"}, + {file = "protobuf-6.33.6-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:e9db7e292e0ab79dd108d7f1a94fe31601ce1ee3f7b79e0692043423020b0593"}, + {file = "protobuf-6.33.6-cp39-cp39-win32.whl", hash = "sha256:bd56799fb262994b2c2faa1799693c95cc2e22c62f56fb43af311cae45d26f0e"}, + {file = "protobuf-6.33.6-cp39-cp39-win_amd64.whl", hash = "sha256:f443a394af5ed23672bc6c486be138628fbe5c651ccbc536873d7da23d1868cf"}, + {file = "protobuf-6.33.6-py3-none-any.whl", hash = "sha256:77179e006c476e69bf8e8ce866640091ec42e1beb80b213c3900006ecfba6901"}, + {file = "protobuf-6.33.6.tar.gz", hash = "sha256:a6768d25248312c297558af96a9f9c929e8c4cee0659cb07e780731095f38135"}, +] + [[package]] name = "psutil" version = "7.2.2" @@ -5480,10 +5723,17 @@ description = "Tensors and Dynamic neural networks in Python with strong GPU acc optional = true python-versions = ">=3.10" files = [ - {file = "torch-2.10.0-1-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:c37fc46eedd9175f9c81814cc47308f1b42cfe4987e532d4b423d23852f2bf63"}, - {file = "torch-2.10.0-1-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:f699f31a236a677b3118bc0a3ef3d89c0c29b5ec0b20f4c4bf0b110378487464"}, - {file = "torch-2.10.0-1-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:6abb224c2b6e9e27b592a1c0015c33a504b00a0e0938f1499f7f514e9b7bfb5c"}, - {file = "torch-2.10.0-1-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:7350f6652dfd761f11f9ecb590bfe95b573e2961f7a242eccb3c8e78348d26fe"}, + {file = "torch-2.10.0-2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:2b980edd8d7c0a68c4e951ee1856334a43193f98730d97408fbd148c1a933313"}, + {file = "torch-2.10.0-2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:418997cb02d0a0f1497cf6a09f63166f9f5df9f3e16c8a716ab76a72127c714f"}, + {file = "torch-2.10.0-2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:13ec4add8c3faaed8d13e0574f5cd4a323c11655546f91fbe6afa77b57423574"}, + {file = "torch-2.10.0-2-cp313-none-macosx_11_0_arm64.whl", hash = "sha256:e521c9f030a3774ed770a9c011751fb47c4d12029a3d6522116e48431f2ff89e"}, + {file = "torch-2.10.0-3-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a1ff626b884f8c4e897c4c33782bdacdff842a165fee79817b1dd549fdda1321"}, + {file = "torch-2.10.0-3-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ac5bdcbb074384c66fa160c15b1ead77839e3fe7ed117d667249afce0acabfac"}, + {file = "torch-2.10.0-3-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:98c01b8bb5e3240426dcde1446eed6f40c778091c8544767ef1168fc663a05a6"}, + {file = "torch-2.10.0-3-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:80b1b5bfe38eb0e9f5ff09f206dcac0a87aadd084230d4a36eea5ec5232c115b"}, + {file = "torch-2.10.0-3-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:46b3574d93a2a8134b3f5475cfb98e2eb46771794c57015f6ad1fb795ec25e49"}, + {file = "torch-2.10.0-3-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:b1d5e2aba4eb7f8e87fbe04f86442887f9167a35f092afe4c237dfcaaef6e328"}, + {file = "torch-2.10.0-3-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:0228d20b06701c05a8f978357f657817a4a63984b0c90745def81c18aedfa591"}, {file = "torch-2.10.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:5276fa790a666ee8becaffff8acb711922252521b28fbce5db7db5cf9cb2026d"}, {file = "torch-2.10.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:aaf663927bcd490ae971469a624c322202a2a1e68936eb952535ca4cd3b90444"}, {file = "torch-2.10.0-cp310-cp310-win_amd64.whl", hash = "sha256:a4be6a2a190b32ff5c8002a0977a25ea60e64f7ba46b1be37093c141d9c49aeb"}, @@ -5972,9 +6222,10 @@ test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more_it type = ["pytest-mypy"] [extras] -all = ["catboost", "cupy-cuda12x", "cupy-cuda12x", "ipywidgets", "nbformat", "nmslib", "nmslib-metabrainz", "plotly", "pytorch-lightning", "pytorch-lightning", "rectools-lightfm", "torch", "torch", "torch"] +all = ["catboost", "cupy-cuda12x", "cupy-cuda12x", "ipywidgets", "k-means-constrained", "nbformat", "nmslib", "nmslib-metabrainz", "plotly", "pytorch-lightning", "pytorch-lightning", "rectools-lightfm", "torch", "torch", "torch"] catboost = ["catboost"] cupy = ["cupy-cuda12x", "cupy-cuda12x"] +k-means-constrained = ["k-means-constrained"] lightfm = ["rectools-lightfm"] nmslib = ["nmslib", "nmslib-metabrainz"] torch = ["pytorch-lightning", "pytorch-lightning", "torch", "torch", "torch"] @@ -5983,4 +6234,4 @@ visuals = ["ipywidgets", "nbformat", "plotly"] [metadata] lock-version = "2.0" python-versions = ">=3.9, <3.14" -content-hash = "16e669449fa37e7bbf6ec621fb42bad756aa55b77d503b26839e3d388b30fd0a" +content-hash = "91bfbe4a6a0aeb25c5d7e90c999546db562f39b17501d74cd022a8475b0192c4" diff --git a/pyproject.toml b/pyproject.toml index 8423c7ef..042b85a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,6 +117,7 @@ cupy-cuda12x = [ # and cupy-cuda12x, so we add the version restriction here manually to avoid # installing older version of fastrlock which is incompatible with Python 3.13 fastrlock = {version = "^0.8.3", markers = "sys_platform != 'darwin'", optional = true} +k-means-constrained = {version = "^0.7.0", optional = true} [tool.poetry.extras] lightfm = ["rectools-lightfm"] @@ -125,6 +126,7 @@ torch = ["torch", "pytorch-lightning"] visuals = ["ipywidgets", "plotly", "nbformat"] cupy = ["cupy-cuda12x"] catboost = ["catboost"] +k-means-constrained = ["k-means-constrained"] all = [ "rectools-lightfm", "nmslib", "nmslib-metabrainz", @@ -132,6 +134,7 @@ all = [ "catboost", "ipywidgets", "plotly", "nbformat", "cupy-cuda12x", + "k-means-constrained", ] diff --git a/rectools/semantic/__init__.py b/rectools/semantic/__init__.py new file mode 100644 index 00000000..4ac0f2ae --- /dev/null +++ b/rectools/semantic/__init__.py @@ -0,0 +1,9 @@ +"""rectools.semantic -- generative recommenders with Semantic ID inputs""" + +from .config_options import LRScheduleType, OptimizerType, QuantizerType + +__all__ = [ + "LRScheduleType", + "OptimizerType", + "QuantizerType", +] diff --git a/rectools/semantic/config_options.py b/rectools/semantic/config_options.py new file mode 100644 index 00000000..77c64f23 --- /dev/null +++ b/rectools/semantic/config_options.py @@ -0,0 +1,5 @@ +import typing as tp + +OptimizerType = tp.Literal["adam", "adagrad", "adamw"] +QuantizerType = tp.Literal["rqvae", "rkmeans"] +LRScheduleType = tp.Literal["constant", "cosine", "linear"] diff --git a/rectools/semantic/data_handling/__init__.py b/rectools/semantic/data_handling/__init__.py new file mode 100644 index 00000000..4d6adc80 --- /dev/null +++ b/rectools/semantic/data_handling/__init__.py @@ -0,0 +1,12 @@ +"""Vectorized sequence preprocessing for TIGER generative recommender""" + +from .dataset import PaddingCollateFn, TIGERDataset +from .k_core import k_core +from .loo_split import loo_split + +__all__ = [ + "PaddingCollateFn", + "TIGERDataset", + "k_core", + "loo_split", +] diff --git a/rectools/semantic/data_handling/dataset.py b/rectools/semantic/data_handling/dataset.py new file mode 100644 index 00000000..ae96fbb0 --- /dev/null +++ b/rectools/semantic/data_handling/dataset.py @@ -0,0 +1,194 @@ +import typing as tp + +import numpy as np +import pandas as pd +import torch +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import Dataset as TorchDataset +from tqdm.auto import tqdm + +from rectools.semantic.tokenizer import SIDTokenizer + + +class TIGERDataset(TorchDataset): # pylint: disable=too-many-instance-attributes + """Dataset for the TIGER generative retrieval model. + + Accepts a pandas DataFrame of interactions, builds per-user item + sequences, and tokenizes them into Semantic IDs. + + In train mode (``eval_mode=False``), each sample is a + (history, next-item) pair: + - ``input_ids`` : 1-D flattened encoder token sequence + - ``dec_input`` : decoder input with BOS prepended + - ``labels`` : raw SID codes of the target item + + In eval mode (``eval_mode=True``), each sample is: + - ``input_ids`` : 1-D flattened encoder token sequence (all items except the last) + - ``labels`` : raw item ID of the target (last item) + + Parameters + ---------- + interactions : pd.DataFrame + Interaction history with at least ``user_col`` and ``item_col`` columns. + tokenizer : SIDTokenizer + Trained tokenizer for converting item IDs to Semantic IDs. + codebook_sizes : list of int + Number of codes per RQ-VAE level. + max_length : int + Maximum number of history items per sequence. + codeword_offset : int + Offset added to codeword indices to reserve special tokens. + bos_token_id : int + Token ID used as beginning-of-sequence in the decoder. + only_last : bool + If True, only the last ``max_length + 1`` items per user are used. + If False, all subsequences of length >= 2 are generated. + eval_mode : bool + If True, the dataset produces evaluation samples (raw item ID as label). + max_users : int or None + If set, subsample to at most this many users. + user_col : str + Name of the user ID column. + item_col : str + Name of the item ID column. + timestamp_col : str + Name of the timestamp column. If present in the DataFrame, + interactions are sorted by (user_col, timestamp_col). + """ + + def __init__( + self, + interactions: pd.DataFrame, + tokenizer: SIDTokenizer, + codebook_sizes: tp.List[int], + max_length: int = 20, + codeword_offset: int = 2, + bos_token_id: int = 1, + only_last: bool = True, + eval_mode: bool = False, + max_users: tp.Optional[int] = None, + user_col: str = "user_id", + item_col: str = "item_id", + timestamp_col: str = "timestamp", + ) -> None: + self.tokenizer = tokenizer + self.codebook_sizes = codebook_sizes + self.sid_len = len(codebook_sizes) + self.max_length = max_length + self.codeword_offset = codeword_offset + self.bos_token_id = bos_token_id + self.only_last = only_last + self.eval_mode = eval_mode + self.max_users = max_users + self.user_col = user_col + self.item_col = item_col + self.timestamp_col = timestamp_col + + self.interactions = interactions + + self.offsets = np.cumsum([codeword_offset] + codebook_sizes)[:-1] + + df = interactions + if timestamp_col in df.columns: + df = df.sort_values([user_col, timestamp_col]) + + user_sequences = df.groupby(user_col)[item_col].agg(list) + + self._create_sequences(user_sequences) + + def _create_sequences(self, user_sequences: pd.DataFrame) -> None: # pylint: disable=too-many-branches + user_ids = user_sequences.index.tolist() + if self.max_users is not None and len(user_ids) > self.max_users: + user_ids = np.random.choice(user_ids, size=self.max_users, replace=False).tolist() + user_sequences = user_sequences.loc[user_ids] + + if self.eval_mode: + self.sequences: tp.List[tp.List[tp.Tuple[int, ...]]] = [] + self.targets: tp.List[int] = [] + for seq in user_sequences: + if len(seq) > self.max_length + 1: + tokenized = self.tokenizer.tokenize(seq[-self.max_length - 1 : -1]) + else: + tokenized = self.tokenizer.tokenize(seq[:-1]) + assert isinstance(tokenized, list) + self.sequences.append(tokenized) + self.targets.append(seq[-1]) + elif self.only_last: + self.sequences = [] + for seq in user_sequences: + seq = seq[-self.max_length - 1 :] if len(seq) > self.max_length + 1 else seq + tokenized = self.tokenizer.tokenize(seq) + assert isinstance(tokenized, list) + self.sequences.append(tokenized) + else: + self.sequences = [] + for seq in tqdm(user_sequences, desc="Preparing the train data", ncols=120): + if len(seq) < 2: + continue + for t in range(1, len(seq)): + start = max(0, t - self.max_length) + tokenized = self.tokenizer.tokenize(seq[start : t + 1]) + assert isinstance(tokenized, list) + self.sequences.append(tokenized) + + def __len__(self) -> int: + return len(self.sequences) + + def _sid_to_tokens(self, sid: tp.Tuple[int, ...]) -> tp.List[int]: + return [code + self.offsets[d] for d, code in enumerate(sid)] + + def __getitem__(self, idx: int) -> tp.Dict[str, tp.Any]: + if self.eval_mode: + history_sids: tp.List[tp.Tuple[int, ...]] = self.sequences[idx] + target = self.targets[idx] + + input_ids = np.array( + [tok for sid in history_sids for tok in self._sid_to_tokens(sid)], + dtype=np.int64, + ) + return {"input_ids": input_ids, "labels": target} + + item_sequence: tp.List[tp.Tuple[int, ...]] = self.sequences[idx] + history_sids = item_sequence[:-1] + target_sid: tp.Tuple[int, ...] = item_sequence[-1] + + input_ids = np.array( + [tok for sid in history_sids for tok in self._sid_to_tokens(sid)], + dtype=np.int64, + ) + + target_tokens = np.array(self._sid_to_tokens(target_sid), dtype=np.int64) + + dec_input = np.concatenate([[self.bos_token_id], target_tokens[:-1]]).astype(np.int64) + + labels = np.array(target_sid, dtype=np.int64) + + return {"input_ids": input_ids, "dec_input": dec_input, "labels": labels} + + +class PaddingCollateFn: + """Automatically right pad user interaction sequences and labels with specified padding values. + + Parameters + ---------- + padding_value : int, optional + Value to pad input sequences with, by default 0 + labels_padding_value : int, optional + Value to pad labels with, by default -100 + """ + + def __init__(self, padding_value: int = 0, labels_padding_value: int = -100) -> None: + self.padding_value = padding_value + self.labels_padding_value = labels_padding_value + + def __call__(self, batch: tp.List[tp.Dict[str, tp.Any]]) -> tp.Dict[str, torch.Tensor]: + """Apply padding to all fields in a batch.""" + collated = {} + for key in batch[0]: + if np.isscalar(batch[0][key]): + collated[key] = torch.tensor([ex[key] for ex in batch]) + continue + pad_val = self.labels_padding_value if key == "labels" else self.padding_value + values = [torch.tensor(ex[key]) for ex in batch] + collated[key] = pad_sequence(values, batch_first=True, padding_value=pad_val) + return collated diff --git a/rectools/semantic/data_handling/k_core.py b/rectools/semantic/data_handling/k_core.py new file mode 100644 index 00000000..15401f4c --- /dev/null +++ b/rectools/semantic/data_handling/k_core.py @@ -0,0 +1,48 @@ +import pandas as pd + + +def k_core( + interactions: pd.DataFrame, + user_min_interactions: int = 5, + item_min_interactions: int = 5, + user_col: str = "user_id", + item_col: str = "item_id", +) -> pd.DataFrame: + """Filter user and item interactions using iterative k-core algorithm. + + Removes users with fewer than ``user_min_interactions`` interactions + and items with fewer than ``item_min_interactions`` interactions, + repeating until no more removals are needed. + + Parameters + ---------- + interactions : pd.DataFrame + Interaction history. Must contain ``user_col`` and ``item_col`` columns. + user_min_interactions : int + Minimum number of interactions per user. + item_min_interactions : int + Minimum number of interactions per item. + user_col : str + Name of the user ID column. + item_col : str + Name of the item ID column. + + Returns + ------- + pd.DataFrame + Filtered interactions. + """ + df = interactions + while True: + user_counts = df[user_col].value_counts() + item_counts = df[item_col].value_counts() + bad_users = user_counts[user_counts < user_min_interactions].index + bad_items = item_counts[item_counts < item_min_interactions].index + + if len(bad_users) == 0 and len(bad_items) == 0: + break + + df = df[~df[user_col].isin(bad_users) & ~df[item_col].isin(bad_items)] + df = df.reset_index(drop=True) + + return df diff --git a/rectools/semantic/data_handling/loo_split.py b/rectools/semantic/data_handling/loo_split.py new file mode 100644 index 00000000..891107ed --- /dev/null +++ b/rectools/semantic/data_handling/loo_split.py @@ -0,0 +1,47 @@ +import typing as tp + +import pandas as pd + + +def loo_split( + interactions: pd.DataFrame, + user_col: str = "user_id", + timestamp_col: str = "timestamp", +) -> tp.Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """Leave-one-out split on an interactions DataFrame. + + For each user, the last interaction goes to test, the second-to-last + to validation, and all earlier interactions to train. Users with + fewer than 3 interactions are dropped. + + Parameters + ---------- + interactions : pd.DataFrame + Interaction history. + user_col : str + Name of the user ID column. + timestamp_col : str + Name of the timestamp column. If present, interactions are + sorted by (user_col, timestamp_col) before splitting. + + Returns + ------- + train_df, val_df, test_df : pd.DataFrame + Train contains all but the last 2 interactions per user. + Val contains all but the last interaction per user. + Test contains all interactions. + """ + df = interactions + if timestamp_col in df.columns: + df = df.sort_values([user_col, timestamp_col]) + + grouped = df.groupby(user_col) + valid_users = grouped.filter(lambda x: len(x) >= 3) + grouped = valid_users.groupby(user_col) + + cumcounts = grouped.cumcount(ascending=False) + train_df = valid_users[cumcounts >= 2].reset_index(drop=True) + val_df = valid_users[cumcounts >= 1].reset_index(drop=True) + test_df = valid_users.reset_index(drop=True) + + return train_df, val_df, test_df diff --git a/rectools/semantic/metrics.py b/rectools/semantic/metrics.py new file mode 100644 index 00000000..7a85fa3f --- /dev/null +++ b/rectools/semantic/metrics.py @@ -0,0 +1,61 @@ +import typing as tp +from collections import Counter + + +def coverage_k(all_topk_items: tp.List[tp.List[int]], k: int, num_items: int) -> float: + """Calculate fraction of catalog items that appear in at least one user's top-K. + + Parameters + ---------- + all_topk_items : tp.List[tp.List[int]] + tp.List of all users' top-Ks + k : int + Users' top-K size + num_items : int + Number of unique items in dataset + + Returns + ------- + float + Coverage@k. + """ + recommended = set() + for items in all_topk_items: + for item in items[:k]: + recommended.add(item) + return len(recommended) / num_items + + +def gini_k(all_topk_items: tp.List[tp.List[int]], k: int) -> float: + """Gini coefficient over item recommendation frequencies in top-K. + + 0 = every item recommended equally often, 1 = all recommendations + concentrate on a single item. + + Parameters + ---------- + all_topk_items : _type_ + tp.List of all users' top-Ks + k : int + Users' top-K size + + Returns + ------- + float + Gini@k. + """ + counts: Counter[int] = Counter() + for items in all_topk_items: + for item in items[:k]: + counts[item] += 1 + + if len(counts) == 0: + return 0.0 + + freqs = sorted(counts.values()) + n = len(freqs) + cumulative = 0.0 + for i, freq in enumerate(freqs): + cumulative += (2 * (i + 1) - n - 1) * freq + total = sum(freqs) + return cumulative / (n * total) diff --git a/rectools/semantic/modules/__init__.py b/rectools/semantic/modules/__init__.py new file mode 100644 index 00000000..046923c6 --- /dev/null +++ b/rectools/semantic/modules/__init__.py @@ -0,0 +1,12 @@ +"""torch.nn modules used in TIGER model and RQ-VAE""" + +from .mlp import MLP +from .transformer_blocks import T5DecoderLayer, T5EncoderLayer, T5RMSNorm, T5RelativePositionBias + +__all__ = [ + "MLP", + "T5DecoderLayer", + "T5EncoderLayer", + "T5RMSNorm", + "T5RelativePositionBias", +] diff --git a/rectools/semantic/modules/mlp.py b/rectools/semantic/modules/mlp.py new file mode 100644 index 00000000..de7a624f --- /dev/null +++ b/rectools/semantic/modules/mlp.py @@ -0,0 +1,66 @@ +import typing as tp + +from torch import Tensor, nn + + +class MLP(nn.Module): + """An implementation of multilayer perceptron. + + Parameters + ---------- + input_dim : int + Input dimension. + hidden_dims : tp.List[int] + Dimensions of hidden layers. + out_dim : int + Output dimension. + dropout : float, optional + Dropout probability, by default 0.0 + normalize : bool, optional + Whether to apply batch normalization, by default False + """ + + def __init__( + self, + input_dim: int, + hidden_dims: tp.List[int], + out_dim: int, + dropout: float = 0.0, + normalize: bool = False, + ): + super().__init__() + + self.input_dim = input_dim + self.hidden_dims = hidden_dims + self.out_dim = out_dim + self.dropout = dropout + self.normalize = normalize + + dims = [self.input_dim] + self.hidden_dims + [self.out_dim] + + self.mlp = nn.Sequential() + for i, (in_d, out_d) in enumerate(zip(dims[:-1], dims[1:])): + self.mlp.append(nn.Linear(in_d, out_d, bias=False)) + + if self.normalize: + self.mlp.append(nn.BatchNorm1d(num_features=out_d)) + + if i != len(dims) - 2: + self.mlp.append(nn.ReLU()) + + if dropout != 0: + self.mlp.append(nn.Dropout(dropout)) + + self.apply(self._init_weights) + + def forward(self, x: Tensor) -> Tensor: + """Run a forward pass.""" + assert x.shape[-1] == self.input_dim, f"Invalid input dim: Expected {self.input_dim}, found {x.shape[-1]}" + return self.mlp(x) + + # We just initialize the module with normal distribution as the paper said + def _init_weights(self, module: nn.Module) -> None: + if isinstance(module, nn.Linear): + nn.init.xavier_normal_(module.weight.data) + if module.bias is not None: + module.bias.data.fill_(0.0) diff --git a/rectools/semantic/modules/transformer_blocks.py b/rectools/semantic/modules/transformer_blocks.py new file mode 100644 index 00000000..280f4884 --- /dev/null +++ b/rectools/semantic/modules/transformer_blocks.py @@ -0,0 +1,609 @@ +import math +import typing as tp + +import torch +from torch import nn +from torch.nn.functional import scaled_dot_product_attention + + +class PointWiseFeedForward(nn.Module): + """ + Feed-Forward network to introduce nonlinearity into the transformer model. + This implementation is the one used by SASRec authors. + + Parameters + ---------- + n_factors : int + Latent embeddings size. + n_factors_ff : int + How many hidden units to use in the network. + dropout_rate : float + Probability of a hidden unit to be zeroed. + activation: torch.nn.Module + Activation function module. + bias: bool, default ``True`` + If ``True``, add bias to linear layers. + """ + + def __init__( + self, + n_factors: int, + n_factors_ff: int, + dropout_rate: float, + activation: torch.nn.Module, + bias: bool = True, + ) -> None: + super().__init__() + self.ff_linear_1 = nn.Linear(n_factors, n_factors_ff, bias) + self.ff_dropout_1 = torch.nn.Dropout(dropout_rate) + self.ff_activation = activation + self.ff_linear_2 = nn.Linear(n_factors_ff, n_factors, bias) + + def forward(self, seqs: torch.Tensor) -> torch.Tensor: + """ + Forward pass. + + Parameters + ---------- + seqs : torch.Tensor + User sequences of item embeddings. + + Returns + ------- + torch.Tensor + User sequence that passed through all layers. + """ + output = self.ff_activation(self.ff_linear_1(seqs)) + fin = self.ff_linear_2(self.ff_dropout_1(output)) + return fin + + +class SwigluFeedForward(nn.Module): + """ + Feed-Forward network to introduce nonlinearity into the transformer model. + This implementation is based on FuXi and LLama SwigLU https://arxiv.org/pdf/2502.03036, + LiGR https://arxiv.org/pdf/2502.03417 + + Parameters + ---------- + n_factors : int + Latent embeddings size. + n_factors_ff : int + How many hidden units to use in the network. + dropout_rate : float + Probability of a hidden unit to be zeroed. + bias: bool, default ``True`` + If ``True``, add bias to linear layers. + """ + + def __init__(self, n_factors: int, n_factors_ff: int, dropout_rate: float, bias: bool = True) -> None: + super().__init__() + self.ff_linear_1 = nn.Linear(n_factors, n_factors_ff, bias=bias) + self.ff_dropout_1 = torch.nn.Dropout(dropout_rate) + self.ff_activation = torch.nn.SiLU() + self.ff_linear_2 = nn.Linear(n_factors_ff, n_factors, bias=bias) + self.ff_linear_3 = nn.Linear(n_factors, n_factors_ff, bias=bias) + + def forward(self, seqs: torch.Tensor) -> torch.Tensor: + """ + Forward pass. + + Parameters + ---------- + seqs : torch.Tensor + User sequences of item embeddings. + + Returns + ------- + torch.Tensor + User sequence that passed through all layers. + """ + output = self.ff_activation(self.ff_linear_1(seqs)) * self.ff_linear_3(seqs) + fin = self.ff_linear_2(self.ff_dropout_1(output)) + return fin + + +def init_feed_forward( + n_factors: int, + ff_factors_multiplier: int, + dropout_rate: float, + ff_activation: str, + bias: bool = True, +) -> nn.Module: + """ + Initialise Feed-Forward network with one of activation functions: "swiglu", "relu", "gelu". + + Parameters + ---------- + n_factors : int + Latent embeddings size. + ff_factors_multiplier : int + How many hidden units to use in the network. + dropout_rate : float + Probability of a hidden unit to be zeroed. + ff_activation : {"swiglu", "relu", "gelu"} + Activation function to use. + bias: bool, default ``True`` + If ``True``, add bias to linear layers. + + Returns + ------- + nn.Module + Feed-Forward network. + """ + if ff_activation == "swiglu": + return SwigluFeedForward(n_factors, n_factors * ff_factors_multiplier, dropout_rate, bias=bias) + if ff_activation == "gelu": + return PointWiseFeedForward( + n_factors, + n_factors * ff_factors_multiplier, + dropout_rate, + activation=torch.nn.GELU(), + bias=bias, + ) + if ff_activation == "relu": + return PointWiseFeedForward( + n_factors, + n_factors * ff_factors_multiplier, + dropout_rate, + activation=torch.nn.ReLU(), + bias=bias, + ) + raise ValueError(f"Unsupported ff_activation: {ff_activation}") + + +class T5RMSNorm(nn.Module): + """ + T5-style RMSNorm: normalise by root-mean-square, then scale. + + Unlike ``nn.LayerNorm`` this has **no bias** and does **not** + subtract the mean -- it only divides by the RMS and applies a + learned scale (weight) vector. + + Parameters + ---------- + d_model : int + Feature dimension. + eps : float + Epsilon for numerical stability. + """ + + def __init__(self, d_model: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(d_model)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Run a forward pass.""" + rms = x.float().pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt() + return (x.float() * rms).to(x.dtype) * self.weight + + +class T5RelativePositionBias(nn.Module): + """ + Learned relative-position bias table with log-spaced bucketing, as + described in the T5 paper (Raffel et al., 2019). + + The bias is added directly to the attention logits and is shared + across all layers of the encoder (or decoder). + + Parameters + ---------- + num_heads : int + Number of attention heads (one scalar bias per head). + bidirectional : bool + ``True`` for the encoder (positions attend in both directions), + ``False`` for the decoder (causal -- only attends to earlier + positions). + num_buckets : int + Number of distance buckets. + max_distance : int + Distances beyond this value are clamped to the last bucket. + """ + + def __init__( + self, + num_heads: int, + bidirectional: bool = True, + num_buckets: int = 32, + max_distance: int = 128, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.bidirectional = bidirectional + self.num_buckets = num_buckets + self.max_distance = max_distance + + self.relative_attention_bias = nn.Embedding(num_buckets, num_heads) + + @staticmethod + def _relative_position_bucket( + relative_position: torch.Tensor, + bidirectional: bool, + num_buckets: int, + max_distance: int, + ) -> torch.Tensor: + """Map signed relative positions to bucket indices.""" + ret = torch.zeros_like(relative_position) + + if bidirectional: + num_buckets //= 2 + ret += (relative_position > 0).long() * num_buckets + n = relative_position.abs() + else: + # Clamp future positions to 0 (causal -- no peeking ahead). + n = (-relative_position).clamp(min=0) + + # Half the buckets are for exact small distances. + max_exact = num_buckets // 2 + is_small = n < max_exact + + # The other half are for log-spaced larger distances. + val_if_large = ( + max_exact + + (torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)).long() + ) + val_if_large = val_if_large.clamp(max=num_buckets - 1) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, query_len: int, key_len: int, device: torch.device) -> torch.Tensor: + """ + Compute position bias. + + Returns + ------- + Tensor [1, num_heads, query_len, key_len] + Additive bias to be added to the attention logits. + """ + q_pos = torch.arange(query_len, dtype=torch.long, device=device) + k_pos = torch.arange(key_len, dtype=torch.long, device=device) + # relative_position[i, j] = j - i (key pos minus query pos) + relative_position = k_pos.unsqueeze(0) - q_pos.unsqueeze(1) + + buckets = self._relative_position_bucket( + relative_position, + bidirectional=self.bidirectional, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ) + # [query_len, key_len, num_heads] + values = self.relative_attention_bias(buckets) + # -> [1, num_heads, query_len, key_len] + return values.permute(2, 0, 1).unsqueeze(0) + + +class T5Attention(nn.Module): + """ + Multi-head attention with separate key/value dimension, as in T5. + + Unlike ``nn.MultiheadAttention`` where the per-head dimension is + always ``embed_dim // num_heads``, this module lets you choose + ``d_kv`` freely. Projections have no bias (following T5). + + Uses ``scaled_dot_product_attention`` for the core computation, + which dispatches to FlashAttention / memory-efficient backends + automatically. + + Parameters + ---------- + d_model : int + Model (query input) dimension. + num_heads : int + Number of attention heads. + d_kv : int + Per-head dimension for keys and values (and queries). + dropout : float + Attention dropout probability (applied only during training). + kv_input_dim : int, optional + Input dimension for key/value tensors. Defaults to ``d_model``. + Set to a different value for cross-attention when the encoder + output dimension differs from the decoder's ``d_model``. + """ + + # nn.TransformerEncoder/Decoder access self_attn.batch_first + batch_first: bool = True + + def __init__( + self, + d_model: int, + num_heads: int, + d_kv: int, + dropout: float = 0.0, + kv_input_dim: tp.Optional[int] = None, + ) -> None: + super().__init__() + if kv_input_dim is None: + kv_input_dim = d_model + + self.num_heads = num_heads + self.d_kv = d_kv + self.inner_dim = num_heads * d_kv + + self.q_proj = nn.Linear(d_model, self.inner_dim, bias=False) + self.k_proj = nn.Linear(kv_input_dim, self.inner_dim, bias=False) + self.v_proj = nn.Linear(kv_input_dim, self.inner_dim, bias=False) + self.o_proj = nn.Linear(self.inner_dim, d_model, bias=False) + + self.dropout = dropout + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: tp.Optional[torch.Tensor] = None, + key_padding_mask: tp.Optional[torch.Tensor] = None, + need_weights: bool = False, + is_causal: bool = False, + position_bias: tp.Optional[torch.Tensor] = None, + ) -> tp.Tuple[torch.Tensor, None]: + """ + Parameters + ---------- + query : Tensor [B, T_q, d_model] + key : Tensor [B, T_kv, kv_input_dim] + value : Tensor [B, T_kv, kv_input_dim] + attn_mask : Tensor, optional + Additive float mask broadcastable to ``(B, num_heads, T_q, T_kv)``. + ``-inf`` blocks attention. + key_padding_mask : BoolTensor [B, T_kv], optional + ``True`` at positions that should be masked (padding). + need_weights : bool + Ignored (always returns ``None`` for weights). + is_causal : bool + If ``True``, applies a causal mask inside SDPA. + position_bias : Tensor, optional + Additive relative-position bias broadcastable to + ``(B, num_heads, T_q, T_kv)``. Produced by + ``T5RelativePositionBias``. + + Returns + ------- + (output, None) + output : Tensor [B, T_q, d_model] + """ + B, T_q, _ = query.shape + T_kv = key.shape[1] + + # Project and reshape to (B, num_heads, T, d_kv) + q = self.q_proj(query).view(B, T_q, self.num_heads, self.d_kv).transpose(1, 2) + k = self.k_proj(key).view(B, T_kv, self.num_heads, self.d_kv).transpose(1, 2) + v = self.v_proj(value).view(B, T_kv, self.num_heads, self.d_kv).transpose(1, 2) + + # Convert key_padding_mask [B, T_kv] to additive attention mask + # broadcastable to (B, num_heads, T_q, T_kv). + # The mask may arrive as bool (True=pad) or as a float tensor + # (0/-inf) -- nn.TransformerEncoder runs _canonical_mask which + # converts bool to float before calling the layer. + effective_mask = attn_mask + if key_padding_mask is not None: + if key_padding_mask.dtype == torch.bool: + pad_mask = torch.zeros(key_padding_mask.shape, dtype=q.dtype, device=q.device).masked_fill_( + key_padding_mask, float("-inf") + ) + else: + pad_mask = key_padding_mask.to(dtype=q.dtype) + pad_mask = pad_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, T_kv] + if effective_mask is not None: + effective_mask = effective_mask + pad_mask + else: + effective_mask = pad_mask + + # Add relative position bias to attention logits. + if position_bias is not None: + if effective_mask is not None: + effective_mask = effective_mask + position_bias + else: + effective_mask = position_bias + + # When an explicit mask is provided, don't also pass is_causal=True + # -- SDPA does not allow both at the same time. + if effective_mask is not None: + is_causal = False + + out = scaled_dot_product_attention( # pylint: disable=not-callable + q, + k, + v, + attn_mask=effective_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=is_causal, + ) # [B, num_heads, T_q, d_kv] + + out = out.transpose(1, 2).contiguous().view(B, T_q, self.inner_dim) + out = self.o_proj(out) + return out, None + + +class T5EncoderLayer(nn.Module): + """ + Pre-norm Transformer encoder layer with T5-style attention. + + Architecture per step:: + + x_norm = RMSNorm(x) + x = x + Dropout(SelfAttention(x_norm, x_norm, x_norm)) + x_norm = RMSNorm(x) + x = x + Dropout(FFN(x_norm)) + + Parameters + ---------- + d_model : int + Model dimension. + num_heads : int + Number of attention heads. + d_kv : int + Per-head key/value (and query) dimension. + dim_feedforward : int + Inner dimension of the feed-forward network. + dropout : float + Dropout probability. + activation : str + FFN activation (``"relu"`` or ``"gelu"``). + """ + + def __init__( + self, + d_model: int, + num_heads: int, + d_kv: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + ) -> None: + super().__init__() + self.self_attn = T5Attention( + d_model=d_model, + num_heads=num_heads, + d_kv=d_kv, + dropout=dropout, + ) + self.norm1 = T5RMSNorm(d_model) + self.norm2 = T5RMSNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.activation: nn.Module + if activation == "relu": + self.activation = nn.ReLU() + elif activation == "gelu": + self.activation = nn.GELU() + else: + raise ValueError(f"Unsupported activation: {activation}") + + def forward( + self, + src: torch.Tensor, + src_key_padding_mask: tp.Optional[torch.Tensor] = None, + position_bias: tp.Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Run a forward pass.""" + x = src + x_norm = self.norm1(x) + attn_out, _ = self.self_attn( + x_norm, + x_norm, + x_norm, + key_padding_mask=src_key_padding_mask, + position_bias=position_bias, + ) + x = x + self.dropout1(attn_out) + + x_norm = self.norm2(x) + ff_out = self.linear2(self.activation(self.linear1(x_norm))) + x = x + self.dropout2(ff_out) + return x + + +class T5DecoderLayer(nn.Module): + """ + Pre-norm Transformer decoder layer with T5-style attention. + + Architecture per step:: + + x_norm = RMSNorm(x) + x = x + Dropout(CausalSelfAttention(x_norm, x_norm, x_norm)) + x_norm = RMSNorm(x) + x = x + Dropout(CrossAttention(x_norm, memory, memory)) + x_norm = RMSNorm(x) + x = x + Dropout(FFN(x_norm)) + + Parameters + ---------- + d_model : int + Model dimension. + num_heads : int + Number of attention heads. + d_kv : int + Per-head key/value (and query) dimension. + dim_feedforward : int + Inner dimension of the feed-forward network. + dropout : float + Dropout probability. + activation : str + FFN activation (``"relu"`` or ``"gelu"``). + """ + + def __init__( + self, + d_model: int, + num_heads: int, + d_kv: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + ) -> None: + super().__init__() + self.self_attn = T5Attention( + d_model=d_model, + num_heads=num_heads, + d_kv=d_kv, + dropout=dropout, + ) + self.cross_attn = T5Attention( + d_model=d_model, + num_heads=num_heads, + d_kv=d_kv, + dropout=dropout, + kv_input_dim=d_model, + ) + self.norm1 = T5RMSNorm(d_model) + self.norm2 = T5RMSNorm(d_model) + self.norm3 = T5RMSNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.activation: nn.Module + if activation == "relu": + self.activation = nn.ReLU() + elif activation == "gelu": + self.activation = nn.GELU() + else: + raise ValueError(f"Unsupported activation: {activation}") + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: tp.Optional[torch.Tensor] = None, + memory_key_padding_mask: tp.Optional[torch.Tensor] = None, + position_bias: tp.Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Run a forward pass.""" + x = tgt + + # Self-attention (causal) + x_norm = self.norm1(x) + sa_out, _ = self.self_attn( + x_norm, + x_norm, + x_norm, + attn_mask=tgt_mask, + position_bias=position_bias, + ) + x = x + self.dropout1(sa_out) + + # Cross-attention (no position bias per T5 design) + x_norm = self.norm2(x) + ca_out, _ = self.cross_attn( + x_norm, + memory, + memory, + key_padding_mask=memory_key_padding_mask, + ) + x = x + self.dropout2(ca_out) + + # Feed-forward + x_norm = self.norm3(x) + ff_out = self.linear2(self.activation(self.linear1(x_norm))) + x = x + self.dropout3(ff_out) + return x diff --git a/rectools/semantic/tiger/__init__.py b/rectools/semantic/tiger/__init__.py new file mode 100644 index 00000000..a962e05c --- /dev/null +++ b/rectools/semantic/tiger/__init__.py @@ -0,0 +1,11 @@ +"""TIGER generative recommender model with semantic ID inputs""" + +from .lightning import TIGERLightning +from .model import TIGERModel +from .module import TIGERNet + +__all__ = [ + "TIGERLightning", + "TIGERModel", + "TIGERNet", +] diff --git a/rectools/semantic/tiger/lightning.py b/rectools/semantic/tiger/lightning.py new file mode 100644 index 00000000..4aa57e7d --- /dev/null +++ b/rectools/semantic/tiger/lightning.py @@ -0,0 +1,233 @@ +import typing as tp + +import numpy as np +import pytorch_lightning as pl +import torch +from pytorch_lightning.utilities.types import STEP_OUTPUT, OptimizerLRScheduler + +from rectools.semantic import LRScheduleType, OptimizerType +from rectools.semantic.metrics import coverage_k, gini_k +from rectools.semantic.tokenizer import SIDTokenizer + +from .loss import compute_tiger_loss +from .module import TIGERNet + + +class TIGERLightning(pl.LightningModule): + """A ``pytorch_lightninig` wrapper around TIGERNet which used used for + training, validation and testing of the model. + + Parameters + ---------- + model : TIGERNet + Model to train + tokenizer : SIDTokenizer + Semantic ID tokenizer + optimizer_name : OptimizerType, optional + Which optimizer to use, by default "adamw" + beam_size : int, optional + Beam size for model inference, by default 20 + top_k : int, optional + Used for @k metrics like HR@k or MRR@k, by default 10 + lr : float, optional + Model learning rate, by default 5e-4 + weight_decay : float, optional + Optimizer weight decay, by default 1e-4 + warmup_steps : int | float, optional + Learning rate warm-up steps, by default 300 + lr_schedule : tp.Optional[LRScheduleType], optional + Learning rate schedule to use. Constant LR if set to None, by default None + max_iters : int, optional + Maximum number of iterations for training, by default 1 + num_items : tp.Optional[int], optional + Number of unique items, by default None + """ + + def __init__( + self, + model: TIGERNet, + tokenizer: SIDTokenizer, + optimizer_name: OptimizerType = "adamw", + beam_size: int = 20, + top_k: int = 10, + lr: float = 5e-4, + weight_decay: float = 1e-4, + warmup_steps: int | float = 300, + lr_schedule: tp.Optional[LRScheduleType] = None, + max_iters: int = 1, + num_items: tp.Optional[int] = None, + decode_rng: tp.Optional[np.random.RandomState] = None, + ): + super().__init__() + + self.model = model + self.tokenizer = tokenizer + self.beam_size = beam_size + self.top_k = top_k + self.num_items = num_items + + self.optimizer_name = optimizer_name + self.lr = lr + self.lr_schedule = lr_schedule + self.weight_decay = weight_decay + self.max_iters = max_iters + self.decode_rng = decode_rng + + self._all_topk_items: tp.List[tp.List[int]] = [] + + if isinstance(warmup_steps, float) and 0 < warmup_steps < 1: + self.warmup_steps = round(max_iters * warmup_steps) + else: + self.warmup_steps = int(warmup_steps) + + def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT: + input_ids = batch["input_ids"] + dec_input = batch["dec_input"] + labels = batch["labels"] + + enc_padding_mask = input_ids == TIGERNet.PAD_TOKEN_ID + + logits = self.model(input_ids, dec_input, enc_padding_mask) + loss = compute_tiger_loss(logits, labels) + + self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True) + return loss + + def _valtest_step( # pylint: disable=too-many-locals + self, batch: tp.Dict[str, torch.Tensor], batch_idx: int, mode: tp.Literal["val", "test"] + ) -> STEP_OUTPUT: + prefix = "" if mode == "test" else "val_" + input_sids = batch["input_ids"] + labels = batch["labels"].cpu().tolist() + + metrics = {} + all_topk_items = [] + + enc_padding_mask = input_sids == TIGERNet.PAD_TOKEN_ID + + codes, _scores = self.model.generate(input_sids, enc_padding_mask=enc_padding_mask, beam_size=self.beam_size) + batch_size, num_beams = codes.size(0), codes.size(1) + + # Batch-decode: flatten (batch, beam, depth) -> list of tuples, decode once + codes_cpu = codes.cpu().tolist() + all_sids = [tuple(codes_cpu[i][j]) for i in range(batch_size) for j in range(num_beams)] + all_decoded: tp.List[tp.Optional[int]] = self.tokenizer.decode( # type: ignore[assignment] + all_sids, + rng=self.decode_rng, + ) + + # Reshape decoded items back to (batch, beam) + decoded_grid = [all_decoded[i * num_beams : (i + 1) * num_beams] for i in range(batch_size)] + + # Per-row deduplication + rank finding (small Python loop) + ranks = np.zeros(batch_size, dtype=np.int64) + valid_sids = 0 + for i in range(batch_size): + row = decoded_grid[i] + target = labels[i] + seen = set() + unique_items = [] + rank_pos = 0 + found = False + for item_id in row: + if item_id is not None and item_id not in seen: + valid_sids += 1 + rank_pos += 1 + seen.add(item_id) + unique_items.append(item_id) + if not found and item_id == target: + ranks[i] = rank_pos + found = True + all_topk_items.append(unique_items) + + self._all_topk_items.extend(all_topk_items) + + # Vectorized metric computation + in_top_k = (ranks > 0) & (ranks <= self.top_k) + top_k_ranks = ranks[in_top_k] + + hit_sum = in_top_k.sum() + + metrics[f"{prefix}Hit@{self.top_k}"] = hit_sum / batch_size + metrics[f"{prefix}NDCG@{self.top_k}"] = ( + float((1.0 / np.log2(top_k_ranks + 1)).sum()) if hit_sum > 0 else 0.0 + ) / batch_size + metrics[f"{prefix}MRR@{self.top_k}"] = (float((1.0 / top_k_ranks).sum()) if hit_sum > 0 else 0.0) / batch_size + self.log_dict(metrics, on_epoch=True, on_step=False, prog_bar=True) + return metrics + + def _on_valtest_epoch_end(self, mode: tp.Literal["val", "test"]) -> STEP_OUTPUT: + prefix = "" if mode == "test" else "val_" + metrics = {} + metrics[f"{prefix}Gini@{self.top_k}"] = gini_k(self._all_topk_items, self.top_k) + if self.num_items is not None: + metrics[f"{prefix}Coverage@{self.top_k}"] = coverage_k(self._all_topk_items, self.top_k, self.num_items) + self.log_dict(metrics, on_epoch=True, on_step=False, prog_bar=True) + return metrics + + def on_validation_epoch_start(self) -> None: + self._all_topk_items = [] + + def validation_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT: + return self._valtest_step(batch, batch_idx, mode="val") + + def on_validation_epoch_end(self) -> None: + self._on_valtest_epoch_end(mode="val") + + def on_test_epoch_start(self) -> None: + self._all_topk_items = [] + + def test_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> STEP_OUTPUT: + return self._valtest_step(batch, batch_idx, mode="test") + + def on_test_epoch_end(self) -> None: + self._on_valtest_epoch_end(mode="test") + + def configure_optimizers(self) -> OptimizerLRScheduler: + optimizer: torch.optim.Optimizer + if self.optimizer_name == "adamw": + optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay) + elif self.optimizer_name == "adagrad": + optimizer = torch.optim.Adagrad(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay) + elif self.optimizer_name == "adam": + optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay) + else: + raise ValueError(f"Unknown optimizer: {self.optimizer_name}") + + if self.lr_schedule is None: + return optimizer + + schedulers: tp.List[torch.optim.lr_scheduler.LRScheduler] = [] + milestones: tp.List[int] = [] + if self.warmup_steps > 0: + schedulers.append( + torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=self.warmup_steps) + ) + milestones = [self.warmup_steps] + if self.lr_schedule == "cosine": + t_0 = max(1, self.max_iters) + schedulers.append( + torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + optimizer, + T_0=t_0, + ) + ) + elif self.lr_schedule == "linear": + schedulers.append(torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.0)) + else: + schedulers.append( + torch.optim.lr_scheduler.ConstantLR( + optimizer, + factor=1.0, + ) + ) + scheduler = torch.optim.lr_scheduler.SequentialLR( + optimizer, + schedulers, + milestones=milestones, + ) + + return { + "optimizer": optimizer, + "lr_scheduler": {"scheduler": scheduler, "interval": "step"}, + } diff --git a/rectools/semantic/tiger/loss.py b/rectools/semantic/tiger/loss.py new file mode 100644 index 00000000..13664612 --- /dev/null +++ b/rectools/semantic/tiger/loss.py @@ -0,0 +1,28 @@ +import typing as tp + +import torch +from torch.nn.functional import cross_entropy + + +def compute_tiger_loss(logits: tp.List[torch.Tensor], labels: torch.Tensor) -> torch.Tensor: + """ + Compute cross-entropy loss for TIGER single-target prediction. + + Parameters + ---------- + logits : list of Tensor, each [B, codebook_size_d] + One tensor per codebook level. + labels : LongTensor [B, sid_len] + Raw SID codes (0-based, no offset) for the target item. + """ + total_loss = torch.tensor(0.0) + sid_len = labels.shape[-1] + + for d in range(sid_len): + total_loss = total_loss + cross_entropy( + logits[d], # [B, codebook_size_d] + labels[:, d], # [B] + ignore_index=-100, + ) + + return total_loss / sid_len diff --git a/rectools/semantic/tiger/model.py b/rectools/semantic/tiger/model.py new file mode 100644 index 00000000..585518d4 --- /dev/null +++ b/rectools/semantic/tiger/model.py @@ -0,0 +1,396 @@ +import json +import os +import typing as tp + +import numpy as np +import pandas as pd +import pytorch_lightning as pl +import torch +from pytorch_lightning.callbacks import LearningRateMonitor, ModelSummary +from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from torch.utils.data import DataLoader +from tqdm.auto import trange + +from rectools.semantic import LRScheduleType, OptimizerType +from rectools.semantic.data_handling import PaddingCollateFn, TIGERDataset +from rectools.semantic.tokenizer import SIDTokenizer + +from .lightning import TIGERLightning +from .module import TIGERNet + + +class TIGERModel: # pylint: disable=too-many-instance-attributes + """TIGER generative recommender model with semantic IDs. + + Given a pretrained SIDTokenizer, trains a TIGER generative recommender model. + + Parameters + ---------- + tokenizer : SIDTokenizer + Pretrained semantic ID tokenizer. + """ + + # pylint: disable=too-many-arguments,too-many-locals + def __init__( + self, + tokenizer: SIDTokenizer, + # TIGER model hyperparams + hidden_units: int = 256, + num_blocks: int = 2, + num_heads: int = 1, + dropout_rate: float = 0.1, + max_length: int = 200, + ff_dim: tp.Optional[int] = None, + d_kv: tp.Optional[int] = None, + # Training hyperparams + optimizer: OptimizerType = "adamw", + lr: float = 1e-3, + lr_schedule: tp.Optional[LRScheduleType] = None, + warmup_steps: int = 300, + weight_decay: float = 1e-4, + max_epochs: int = 100, + patience: tp.Optional[int] = 10, + batch_size: int = 256, + eval_batch_size: int = 64, + num_workers: int = 4, + beam_size: int = 20, + top_k: int = 10, + train_only_last: bool = True, + random_seed: tp.Optional[int] = None, + # General + device: tp.Optional[str] = None, + ): + self.device = device or ("cuda" if torch.cuda.is_available() else "cpu") + + self.tokenizer = tokenizer + self.codebook_sizes = list(tokenizer.quantizer.codebook_sizes) + + # Training hyperparams + self.optimizer = optimizer + self.lr = lr + self.lr_schedule = lr_schedule + self.warmup_steps = warmup_steps + self.weight_decay = weight_decay + self.max_epochs = max_epochs + self.patience = patience + self.batch_size = batch_size + self.eval_batch_size = eval_batch_size + self.num_workers = num_workers + self.beam_size = beam_size + self.top_k = top_k + self.train_only_last = train_only_last + self.random_seed = random_seed + self._decode_rng = np.random.RandomState(random_seed) if random_seed is not None else None + + # TIGER model hyperparams + self.hidden_units = hidden_units + self.num_blocks = num_blocks + self.num_heads = num_heads + self.dropout_rate = dropout_rate + self.max_length = max_length + self.ff_dim = ff_dim + self.d_kv = d_kv + + self.model = TIGERNet( + codebook_sizes=self.codebook_sizes, + hidden_units=hidden_units, + num_blocks=num_blocks, + num_heads=num_heads, + dropout_rate=dropout_rate, + max_length=max_length, + ff_dim=ff_dim, + d_kv=d_kv, + ).to(self.device) + + def _tiger_dataset_kwargs(self) -> dict: + return { + "tokenizer": self.tokenizer, + "codebook_sizes": self.codebook_sizes, + "max_length": self.max_length, + "codeword_offset": TIGERNet.CODEWORD_OFFSET, + "bos_token_id": TIGERNet.BOS_TOKEN_ID, + } + + def fit( + self, + train_df: pd.DataFrame, + val_df: pd.DataFrame, + monitor: tp.Literal["Hit", "MRR", "NDCG"] = "Hit", + ) -> None: + """Train the TIGER model. + + Parameters + ---------- + train_df : pd.DataFrame + Training interactions (user_id, item_id, optionally timestamp). + val_df : pd.DataFrame + Validation interactions (same schema). + + Notes + ----- + If multiple items share the same Semantic ID, decoding during + validation remains stochastic. Set ``random_seed`` at model + initialization to make this collision resolution reproducible. + """ + num_items = len(set(train_df["item_id"]).union(val_df["item_id"])) + + ds_kwargs = self._tiger_dataset_kwargs() + + train_dataset = TIGERDataset( + interactions=train_df, + only_last=self.train_only_last, + eval_mode=False, + **ds_kwargs, + ) + + val_dataset = TIGERDataset( + interactions=val_df, + only_last=True, + eval_mode=True, + **ds_kwargs, + ) + + train_loader = DataLoader( + train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers, + collate_fn=PaddingCollateFn(padding_value=TIGERNet.PAD_TOKEN_ID), + pin_memory=True, + ) + + val_loader = DataLoader( + val_dataset, + batch_size=self.eval_batch_size, + num_workers=self.num_workers, + collate_fn=PaddingCollateFn(padding_value=TIGERNet.PAD_TOKEN_ID), + pin_memory=True, + ) + + lightning_model = TIGERLightning( + model=self.model, + tokenizer=self.tokenizer, + optimizer_name=self.optimizer, + lr=self.lr, + lr_schedule=self.lr_schedule, + warmup_steps=self.warmup_steps, + max_iters=self.max_epochs * len(train_loader) - self.warmup_steps, + weight_decay=self.weight_decay, + num_items=num_items, + beam_size=self.beam_size, + top_k=self.top_k, + ) + + callbacks = [LearningRateMonitor(logging_interval="step"), ModelSummary()] + if self.patience is not None: + callbacks.append( + EarlyStopping( + monitor=f"val_{monitor}@{self.top_k}", + patience=self.patience, + mode="max", + ) + ) + + trainer = pl.Trainer(max_epochs=self.max_epochs, callbacks=callbacks) + trainer.fit( + lightning_model, + train_dataloaders=train_loader, + val_dataloaders=val_loader, + ) + + self.model = lightning_model.model.to(self.device) + + @torch.no_grad() + def predict( + self, + interactions: pd.DataFrame, + top_k: int = 10, + ) -> pd.DataFrame: + """Generate top-k recommendations for each user. + + Parameters + ---------- + interactions : pd.DataFrame + User interaction histories (user_id, item_id, optionally timestamp). + top_k : int + Number of recommendations per user. + + Returns + ------- + pd.DataFrame + Recommendations with columns: user_id, item_id, score, rank. + + Notes + ----- + If multiple items share the same Semantic ID, recommendation + deduplication depends on stochastic SID decoding. Set + ``random_seed`` at model initialization to make outputs + reproducible. + """ + self.model.eval() + + df = interactions + if "timestamp" in df.columns: + df = df.sort_values(["user_id", "timestamp"]) + + user_sequences = df.groupby("user_id")["item_id"].agg(list) + user_ids = user_sequences.index.tolist() + sequences = user_sequences.tolist() + + offsets = np.cumsum([TIGERNet.CODEWORD_OFFSET] + self.codebook_sizes)[:-1] + + enc_tokens_list = [] + for seq in sequences: + seq = seq[-self.max_length :] + sids: tp.List[tp.Tuple[int]] = self.tokenizer.tokenize(seq) # type: ignore[assignment] + tokens: tp.List[int] = [] + for sid in sids: + tokens.extend(code + offsets[d] for d, code in enumerate(sid)) + enc_tokens_list.append(torch.tensor(tokens, dtype=torch.long)) + + rows: tp.List[tp.Tuple] = [] + n_users = len(enc_tokens_list) + + for start in trange(0, n_users, self.eval_batch_size, desc="Generating predictions"): + end = min(start + self.eval_batch_size, n_users) + batch_tokens = enc_tokens_list[start:end] + batch_user_ids = user_ids[start:end] + + batch_max_len = max(t.size(0) for t in batch_tokens) + enc_input = torch.full( + (len(batch_tokens), batch_max_len), + TIGERNet.PAD_TOKEN_ID, + dtype=torch.long, + device=self.device, + ) + for i, t in enumerate(batch_tokens): + enc_input[i, : t.size(0)] = t.to(self.device) + + enc_padding_mask = enc_input == TIGERNet.PAD_TOKEN_ID + + codes, scores = self.model.generate(enc_input, enc_padding_mask=enc_padding_mask, beam_size=top_k) + + batch_size, num_beams = codes.size(0), codes.size(1) + codes_cpu = codes.cpu().tolist() + scores_cpu = scores.cpu() + + all_sids = [tuple(codes_cpu[i][j]) for i in range(batch_size) for j in range(num_beams)] + all_decoded: tp.List[tp.Optional[int]] = self.tokenizer.decode( # type: ignore[assignment] + all_sids, + rng=self._decode_rng, + ) + + for i in range(batch_size): + uid = batch_user_ids[i] + row_decoded = all_decoded[i * num_beams : (i + 1) * num_beams] + row_scores = scores_cpu[i] + seen: tp.Set[int] = set() + rank = 0 + for j, item_id in enumerate(row_decoded): + if item_id is not None and item_id not in seen: + seen.add(item_id) + rank += 1 + rows.append((uid, item_id, row_scores[j].item(), rank)) + if rank >= top_k: + break + + return pd.DataFrame(rows, columns=["user_id", "item_id", "score", "rank"]) + + def evaluate( + self, + test_df: pd.DataFrame, + top_k: tp.Optional[int] = None, + ) -> tp.Mapping[str, float]: + """Evaluate the model on a test set. + + Parameters + ---------- + test_df : pd.DataFrame + Test interactions (user_id, item_id, optionally timestamp). + For each user, the last item is treated as the ground-truth target. + top_k : int, optional + Number of recommendations to generate. Defaults to ``self.top_k``. + + Returns + ------- + dict + Metric name -> value (Hit@k, NDCG@k, MRR@k, etc.). + + Notes + ----- + If multiple items share the same Semantic ID, decoding during + evaluation remains stochastic. Set ``random_seed`` at model + initialization to make metric computation reproducible. + """ + if top_k is None: + top_k = self.top_k + + num_items = test_df["item_id"].nunique() + + test_dataset = TIGERDataset( + interactions=test_df, + eval_mode=True, + only_last=True, + **self._tiger_dataset_kwargs(), + ) + + test_loader = DataLoader( + test_dataset, + batch_size=self.eval_batch_size, + num_workers=self.num_workers, + collate_fn=PaddingCollateFn(padding_value=TIGERNet.PAD_TOKEN_ID), + pin_memory=True, + ) + + lightning_model = TIGERLightning( + model=self.model, + tokenizer=self.tokenizer, + num_items=num_items, + beam_size=self.beam_size, + top_k=top_k, + decode_rng=self._decode_rng, + ) + + trainer = pl.Trainer() + results = trainer.test(lightning_model, dataloaders=test_loader)[0] + self.model = self.model.to(self.device) + return results + + def save(self, directory: str) -> None: + """Save the model to a directory. + + Writes ``tokenizer.pt``, ``model.pt``, and ``config.json``. + """ + os.makedirs(directory, exist_ok=True) + self.tokenizer.save(os.path.join(directory, "tokenizer.pt")) + torch.save(self.model.state_dict(), os.path.join(directory, "model.pt")) + + config = { + "hidden_units": self.hidden_units, + "num_blocks": self.num_blocks, + "num_heads": self.num_heads, + "dropout_rate": self.dropout_rate, + "max_length": self.max_length, + "ff_dim": self.ff_dim, + "d_kv": self.d_kv, + "random_seed": self.random_seed, + } + with open(os.path.join(directory, "config.json"), "w", encoding="utf-8") as f: + json.dump(config, f, indent=2) + + @classmethod + def load(cls, directory: str, device: tp.Optional[str] = None) -> "TIGERModel": + """Load a saved TIGERModel from a directory.""" + device = device or ("cuda" if torch.cuda.is_available() else "cpu") + + with open(os.path.join(directory, "config.json"), encoding="utf-8") as f: + config = json.load(f) + + tokenizer = SIDTokenizer.load(os.path.join(directory, "tokenizer.pt"), device=device) + + tiger_model = cls(tokenizer=tokenizer, device=device, **config) + + state_dict = torch.load(os.path.join(directory, "model.pt"), map_location=device) + tiger_model.model.load_state_dict(state_dict) + + return tiger_model diff --git a/rectools/semantic/tiger/module.py b/rectools/semantic/tiger/module.py new file mode 100644 index 00000000..3efb95b0 --- /dev/null +++ b/rectools/semantic/tiger/module.py @@ -0,0 +1,542 @@ +import math +import typing as tp + +import torch +from torch import nn +from torch.nn.functional import log_softmax + +from rectools.semantic.modules import T5DecoderLayer, T5EncoderLayer, T5RMSNorm, T5RelativePositionBias + + +class T5Encoder(nn.Module): + """ + T5-style encoder stack with shared relative position bias. + + Replaces ``nn.TransformerEncoder`` for the T5 code path so that a + single ``T5RelativePositionBias`` is computed once and shared + across all layers. + + Parameters + ---------- + d_model : int + Model dimension. + num_heads : int + Number of attention heads. + d_kv : int + Per-head key/value (and query) dimension. + num_layers : int + Number of encoder layers. + dim_feedforward : int + FFN inner dimension. + dropout : float + Dropout probability. + activation : str + FFN activation. + relative_position_num_buckets : int + Number of buckets for relative position bias. + relative_position_max_distance : int + Max distance for relative position bucketing. + """ + + def __init__( + self, + d_model: int, + num_heads: int, + d_kv: int, + num_layers: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + relative_position_num_buckets: int = 32, + relative_position_max_distance: int = 128, + ) -> None: + super().__init__() + self.layers = nn.ModuleList( + [ + T5EncoderLayer( + d_model=d_model, + num_heads=num_heads, + d_kv=d_kv, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + ) + for _ in range(num_layers) + ] + ) + self.final_norm = T5RMSNorm(d_model) + self.position_bias = T5RelativePositionBias( + num_heads=num_heads, + bidirectional=True, + num_buckets=relative_position_num_buckets, + max_distance=relative_position_max_distance, + ) + + def forward( + self, + src: torch.Tensor, + src_key_padding_mask: tp.Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pos_bias = self.position_bias(query_len=src.size(1), key_len=src.size(1), device=src.device) + x = src + for layer in self.layers: + x = layer(x, src_key_padding_mask=src_key_padding_mask, position_bias=pos_bias) + return self.final_norm(x) + + +class T5Decoder(nn.Module): + """ + T5-style decoder stack with shared relative position bias. + + Replaces ``nn.TransformerDecoder`` for the T5 code path so that a + single ``T5RelativePositionBias`` (causal / unidirectional) is + computed once and shared across all layers. + + Parameters + ---------- + d_model : int + Model dimension. + num_heads : int + Number of attention heads. + d_kv : int + Per-head key/value (and query) dimension. + num_layers : int + Number of decoder layers. + dim_feedforward : int + FFN inner dimension. + dropout : float + Dropout probability. + activation : str + FFN activation. + relative_position_num_buckets : int + Number of buckets for relative position bias. + relative_position_max_distance : int + Max distance for relative position bucketing. + """ + + def __init__( + self, + d_model: int, + num_heads: int, + d_kv: int, + num_layers: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + relative_position_num_buckets: int = 32, + relative_position_max_distance: int = 128, + ) -> None: + super().__init__() + self.layers = nn.ModuleList( + [ + T5DecoderLayer( + d_model=d_model, + num_heads=num_heads, + d_kv=d_kv, + dim_feedforward=dim_feedforward, + dropout=dropout, + activation=activation, + ) + for _ in range(num_layers) + ] + ) + self.final_norm = T5RMSNorm(d_model) + self.position_bias = T5RelativePositionBias( + num_heads=num_heads, + bidirectional=False, + num_buckets=relative_position_num_buckets, + max_distance=relative_position_max_distance, + ) + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: tp.Optional[torch.Tensor] = None, + memory_key_padding_mask: tp.Optional[torch.Tensor] = None, + ) -> torch.Tensor: + pos_bias = self.position_bias(query_len=tgt.size(1), key_len=tgt.size(1), device=tgt.device) + x = tgt + for layer in self.layers: + x = layer( + x, + memory, + tgt_mask=tgt_mask, + memory_key_padding_mask=memory_key_padding_mask, + position_bias=pos_bias, + ) + return self.final_norm(x) + + +class TIGERNet(nn.Module): # pylint: disable=too-many-instance-attributes + """ + Generative retrieval model for sequential recommendation. + + Items are represented as Semantic IDs -- tuples of discrete codewords + produced by RQ-VAE. The model is a Transformer encoder-decoder: + * Encoder input : flattened Semantic ID tokens of the user history. + * Decoder output: Semantic ID tokens of the next item, predicted + autoregressively one codeword at a time. + + Special tokens + -------------- + 0 -- padding token (in both encoder and decoder) + 1 -- BOS (beginning-of-sequence) token fed to the decoder at step 0 + + All codeword tokens are offset by +2 so that the vocabulary is: + {0: PAD, 1: BOS, 2 .. 2+sum(codebook_sizes)-1 : codewords} + Each codebook level gets its own contiguous slice of the vocabulary. + + Parameters + ---------- + codebook_sizes : list[int] + Number of codes per RQ-VAE level (e.g. [256, 256, 256]). + hidden_units : int + Dimensionality of the Transformer hidden states. + num_blocks : int + Number of Transformer layers in **each** of the encoder and decoder. + num_heads : int + Number of attention heads. + dropout_rate : float + Dropout probability. + max_length : int + Maximum number of items in the encoder history. + initializer_range : float + Std for weight initialisation. + ff_dim : int, optional + Feed-forward inner dimension. Defaults to ``4 * hidden_units``. + d_kv : int, optional + Per-head key/value (and query) projected dimension for T5-style + attention. When set, the encoder and decoder use T5-style layers + with RMSNorm and relative position bias (no absolute positional + embeddings). When ``None`` (default), standard + ``nn.TransformerEncoderLayer`` / ``nn.TransformerDecoderLayer`` + are used with learned absolute positional embeddings. + relative_position_num_buckets : int + Number of buckets for T5 relative position bias (only used when + ``d_kv`` is set). + relative_position_max_distance : int + Max distance for T5 relative position bucketing (only used when + ``d_kv`` is set). + """ + + PAD_TOKEN_ID = 0 + BOS_TOKEN_ID = 1 + CODEWORD_OFFSET = 2 # first codeword token id + + def __init__( + self, + codebook_sizes: tp.List[int], + hidden_units: int = 256, + num_blocks: int = 2, + num_heads: int = 1, + dropout_rate: float = 0.1, + max_length: int = 256, + initializer_range: float = 0.02, + ff_dim: tp.Optional[int] = None, + d_kv: tp.Optional[int] = None, + relative_position_num_buckets: int = 32, + relative_position_max_distance: int = 128, + ) -> None: + super().__init__() + + self.codebook_sizes = codebook_sizes + self.sid_len = len(codebook_sizes) # number of codewords per item + self.hidden_units = hidden_units + self.num_blocks = num_blocks + self.num_heads = num_heads + self.dropout_rate = dropout_rate + self.max_length = max_length + self.initializer_range = initializer_range + self.d_kv = d_kv + + # Total vocabulary: PAD + BOS + all codeword tokens + self.vocab_size = self.CODEWORD_OFFSET + sum(codebook_sizes) + + self.enc_max_len = max_length * self.sid_len + self.dec_max_len = self.sid_len + + if ff_dim is None: + ff_dim = hidden_units * 4 + + self.token_emb = nn.Embedding(self.vocab_size, hidden_units, padding_idx=self.PAD_TOKEN_ID) + + self.enc_dropout = nn.Dropout(dropout_rate) + self.dec_dropout = nn.Dropout(dropout_rate) + + self.encoder: nn.Module + self.decoder: nn.Module + if d_kv is not None: + # T5 path: relative position bias instead of absolute embeddings. + self.enc_pos_emb = None + self.dec_pos_emb = None + + self.encoder = T5Encoder( + d_model=hidden_units, + num_heads=num_heads, + d_kv=d_kv, + num_layers=num_blocks, + dim_feedforward=ff_dim, + dropout=dropout_rate, + activation="relu", + relative_position_num_buckets=relative_position_num_buckets, + relative_position_max_distance=relative_position_max_distance, + ) + self.decoder = T5Decoder( + d_model=hidden_units, + num_heads=num_heads, + d_kv=d_kv, + num_layers=num_blocks, + dim_feedforward=ff_dim, + dropout=dropout_rate, + activation="relu", + relative_position_num_buckets=relative_position_num_buckets, + relative_position_max_distance=relative_position_max_distance, + ) + else: + # Standard path: learned absolute positional embeddings. + self.enc_pos_emb = nn.Embedding(self.enc_max_len, hidden_units) + self.dec_pos_emb = nn.Embedding(self.dec_max_len + 1, hidden_units) # +1 for BOS step + + encoder_layer = nn.TransformerEncoderLayer( + d_model=hidden_units, + nhead=num_heads, + dim_feedforward=ff_dim, + dropout=dropout_rate, + activation="relu", + batch_first=True, + norm_first=True, + ) + decoder_layer = nn.TransformerDecoderLayer( + d_model=hidden_units, + nhead=num_heads, + dim_feedforward=ff_dim, + dropout=dropout_rate, + activation="relu", + batch_first=True, + norm_first=True, + ) + + self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_blocks, norm=nn.LayerNorm(hidden_units)) + self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_blocks, norm=nn.LayerNorm(hidden_units)) + + self.output_heads = nn.ModuleList() + for cb_size in codebook_sizes: + self.output_heads.append(nn.Linear(hidden_units, cb_size)) + + self.apply(self._init_weights) + + def _init_weights(self, module: nn.Module) -> None: + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, T5RMSNorm): + module.weight.data.fill_(1.0) + + def _codebook_offsets(self, device: torch.device) -> torch.Tensor: + """Return tensor [sid_len] with the per-level offset into the vocab.""" + offsets = [self.CODEWORD_OFFSET] + for s in self.codebook_sizes[:-1]: + offsets.append(offsets[-1] + s) + return torch.tensor(offsets, device=device) + + def sid_to_tokens(self, sid: torch.Tensor) -> torch.Tensor: + offsets = self._codebook_offsets(sid.device) + return sid + offsets + + def tokens_to_sid(self, tokens: torch.Tensor) -> torch.Tensor: + offsets = self._codebook_offsets(tokens.device) + return tokens - offsets + + def encode( + self, + enc_input: torch.Tensor, + enc_padding_mask: tp.Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Encode the user history. + + Parameters + ---------- + enc_input : LongTensor [B, T_enc] + Vocabulary token ids for the encoder (already offset). + enc_padding_mask : BoolTensor [B, T_enc], optional + True where the token is padding. + + Returns + ------- + memory : Tensor [B, T_enc, H] + """ + x = self.token_emb(enc_input) + if self.enc_pos_emb is not None: + positions = torch.arange(enc_input.size(1), device=enc_input.device).unsqueeze(0) + x = x + self.enc_pos_emb(positions) + x = self.enc_dropout(x) + + memory = self.encoder(x, src_key_padding_mask=enc_padding_mask) + return memory + + def decode( + self, + dec_input: torch.Tensor, + memory: torch.Tensor, + enc_padding_mask: tp.Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Decode with teacher forcing. + + Parameters + ---------- + dec_input : LongTensor [B, T_dec] + Decoder input token ids (BOS followed by target tokens shifted right). + memory : Tensor [B, T_enc, H] + Encoder output. + enc_padding_mask : BoolTensor [B, T_enc], optional + Padding mask for cross-attention. + + Returns + ------- + Tensor [B, T_dec, H] + """ + x = self.token_emb(dec_input) + if self.dec_pos_emb is not None: + x = x * math.sqrt(self.hidden_units) + positions = torch.arange(dec_input.size(1), device=dec_input.device).unsqueeze(0) + x = x + self.dec_pos_emb(positions) + x = self.dec_dropout(x) + + # causal mask so position i can only attend to positions <= i + tgt_len = dec_input.size(1) + causal_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len, device=dec_input.device) + + out = self.decoder( + x, + memory, + tgt_mask=causal_mask, + memory_key_padding_mask=enc_padding_mask, + ) + return out + + def forward( + self, + enc_input: torch.Tensor, + dec_input: torch.Tensor, + enc_padding_mask: tp.Optional[torch.Tensor] = None, + ) -> tp.List[torch.Tensor]: + memory = self.encode(enc_input, enc_padding_mask) + hidden = self.decode(dec_input, memory, enc_padding_mask) # [B, sid_len, H] + + logits = [] + for d in range(self.sid_len): + logits.append(self.output_heads[d](hidden[:, d, :])) # [B, cb_size_d] + + return logits + + @torch.no_grad() + def generate_greedy( + self, + enc_input: torch.Tensor, + enc_padding_mask: tp.Optional[torch.Tensor] = None, + ) -> torch.Tensor: + memory = self.encode(enc_input, enc_padding_mask) + B = enc_input.size(0) + device = enc_input.device + + # Start with BOS token + generated_tokens = [torch.full((B,), self.BOS_TOKEN_ID, dtype=torch.long, device=device)] + offsets = self._codebook_offsets(device) + predicted_codes = [] + + for d in range(self.sid_len): + dec_input = torch.stack(generated_tokens, dim=1) # [B, d+1] + hidden = self.decode(dec_input, memory, enc_padding_mask) # [B, d+1, H] + logits_d = self.output_heads[d](hidden[:, -1, :]) # [B, cb_size_d] + code_d = logits_d.argmax(dim=-1) # [B] + predicted_codes.append(code_d) + + # Convert raw code to vocab token for the next decoder step + token_d = code_d + offsets[d] + generated_tokens.append(token_d) + + return torch.stack(predicted_codes, dim=1) # [B, sid_len] + + @torch.no_grad() + def generate( # pylint: disable=too-many-locals + self, + enc_input: torch.Tensor, + enc_padding_mask: tp.Optional[torch.Tensor] = None, + beam_size: int = 10, + ) -> tp.Tuple[torch.Tensor, torch.Tensor]: + """ + Beam search decoding (fully batched). + + Returns + ------- + all_sids : LongTensor [B, beam_size, sid_len] + Top-``beam_size`` predicted Semantic IDs (raw codewords, 0-based). + all_scores : Tensor [B, beam_size] + Log-probability scores for each beam. + """ + memory = self.encode(enc_input, enc_padding_mask) + B = enc_input.size(0) + device = enc_input.device + offsets = self._codebook_offsets(device) + + K = beam_size + BK = B * K + + memory = memory.unsqueeze(1).expand(-1, K, -1, -1).reshape(BK, memory.size(1), memory.size(2)) + if enc_padding_mask is not None: + enc_padding_mask = enc_padding_mask.unsqueeze(1).expand(-1, K, -1).reshape(BK, -1) + + scores = torch.zeros(B, K, device=device) + scores[:, 1:] = -1e9 + + all_codes = torch.zeros(B, K, 0, dtype=torch.long, device=device) + + dec_tokens = torch.full((BK, 1), self.BOS_TOKEN_ID, dtype=torch.long, device=device) + + for d in range(self.sid_len): + hidden = self.decode(dec_tokens, memory, enc_padding_mask) + logits_d = self.output_heads[d](hidden[:, -1, :]) + log_probs = log_softmax(logits_d, dim=-1) + + cb_size = self.codebook_sizes[d] + topk_k = min(K, cb_size) + topk_vals, topk_idx = log_probs.topk(topk_k, dim=-1) + + topk_vals = topk_vals.view(B, K, topk_k) + topk_idx = topk_idx.view(B, K, topk_k) + + candidate_scores = scores.unsqueeze(-1) + topk_vals + candidate_scores = candidate_scores.view(B, K * topk_k) + candidate_idx = topk_idx.view(B, K * topk_k) + candidate_beam = ( + torch.arange(K, device=device).unsqueeze(-1).expand(-1, topk_k).reshape(1, K * topk_k).expand(B, -1) + ) + + best_scores, best_flat = candidate_scores.topk(K, dim=-1) + best_beam = candidate_beam.gather(1, best_flat) + best_code = candidate_idx.gather(1, best_flat) + + scores = best_scores + + prev_codes = all_codes.gather(1, best_beam.unsqueeze(-1).expand(-1, -1, d)) + all_codes = torch.cat([prev_codes, best_code.unsqueeze(-1)], dim=-1) + + beam_offset = torch.arange(B, device=device).unsqueeze(1) * K + flat_beam_idx = (beam_offset + best_beam).view(BK) + + dec_tokens = dec_tokens[flat_beam_idx] + new_token = (best_code + offsets[d]).view(BK, 1) + dec_tokens = torch.cat([dec_tokens, new_token], dim=1) + + return all_codes, scores diff --git a/rectools/semantic/tokenizer/__init__.py b/rectools/semantic/tokenizer/__init__.py new file mode 100644 index 00000000..0de6bd42 --- /dev/null +++ b/rectools/semantic/tokenizer/__init__.py @@ -0,0 +1,9 @@ +"""Item ID tokenizer based on RQ-VAE and R-KMeans""" + +from .emb_dataset import EmbDataset +from .model import SIDTokenizer + +__all__ = [ + "EmbDataset", + "SIDTokenizer", +] diff --git a/rectools/semantic/tokenizer/base_quantizer.py b/rectools/semantic/tokenizer/base_quantizer.py new file mode 100644 index 00000000..2b2cc10c --- /dev/null +++ b/rectools/semantic/tokenizer/base_quantizer.py @@ -0,0 +1,37 @@ +import typing as tp +from abc import ABC, abstractmethod + +import torch +from torch import nn + + +class QuantizerOutput(tp.NamedTuple): + """Quantizer outputs.""" + + sem_ids: tp.List[tp.Tuple[int, ...]] + loss: torch.Tensor + + +class Quantizer(ABC, nn.Module): + """Base class for quantizers like RQ-VAE and RK-Means.""" + + def __init__(self, input_dim: int, codebook_sizes: tp.List[int]) -> None: + super().__init__() + self.input_dim = input_dim + self.codebook_sizes = codebook_sizes + self.codebooks = nn.ModuleList() + + @abstractmethod + @torch.no_grad() + def init_codebooks(self, data: torch.Tensor) -> None: + """Initialize codebooks for the quantizer. + + Parameters + ---------- + data : torch.Tensor + Data to initialize codebooks with + """ + + @abstractmethod + def forward(self, inputs: torch.Tensor) -> QuantizerOutput: + """Quantizer forward pass.""" diff --git a/rectools/semantic/tokenizer/emb_dataset.py b/rectools/semantic/tokenizer/emb_dataset.py new file mode 100644 index 00000000..4ebb3b5b --- /dev/null +++ b/rectools/semantic/tokenizer/emb_dataset.py @@ -0,0 +1,41 @@ +import typing as tp + +import numpy as np +from torch.utils.data import Dataset + + +class EmbDataset(Dataset): + """Dataset of item IDs paired with their embeddings. + + Parameters + ---------- + item_ids : list of int + Item identifiers. + embeddings : np.ndarray + Embedding matrix of shape ``(len(item_ids), embed_dim)``. + """ + + def __init__( + self, + item_ids: tp.List[int], + embeddings: np.ndarray, + ) -> None: + if len(item_ids) != len(embeddings): + raise ValueError(f"item_ids length ({len(item_ids)}) != embeddings length ({len(embeddings)})") + self.item_ids: tp.List[int] = list(item_ids) + self.embeddings: np.ndarray = np.asarray(embeddings) + self.dim = self.embeddings.shape[-1] + + def __getitem__(self, idx: tp.Union[int, slice]) -> tp.Dict[str, tp.Any]: + if isinstance(idx, slice): + return { + "item_id": self.item_ids[idx], + "embed": self.embeddings[idx], + } + return { + "item_id": self.item_ids[idx], + "embed": self.embeddings[idx], + } + + def __len__(self) -> int: + return len(self.embeddings) diff --git a/rectools/semantic/tokenizer/loss.py b/rectools/semantic/tokenizer/loss.py new file mode 100644 index 00000000..637f5fe5 --- /dev/null +++ b/rectools/semantic/tokenizer/loss.py @@ -0,0 +1,19 @@ +import torch +from torch import nn +from torch.nn.functional import mse_loss + + +class CodebookLoss(nn.Module): + """Creates a criterion that measures codebook quantization loss between residuals and codebook embeddings.""" + + def __init__(self, mu: float = 1.0, beta: float = 0.25, reduction: str = "mean"): + super().__init__() + self.beta = beta + self.mu = mu + self.reduction = reduction + + def forward(self, y_true: torch.Tensor, y_pred: torch.Tensor) -> torch.Tensor: + """Run a forward pass.""" + return self.mu * mse_loss(y_true.detach(), y_pred, reduction=self.reduction) + self.beta * mse_loss( + y_true, y_pred.detach(), reduction=self.reduction + ) diff --git a/rectools/semantic/tokenizer/model.py b/rectools/semantic/tokenizer/model.py new file mode 100644 index 00000000..cbd22fa2 --- /dev/null +++ b/rectools/semantic/tokenizer/model.py @@ -0,0 +1,490 @@ +import os +import typing as tp +from collections import Counter, defaultdict +from datetime import datetime +from os import PathLike +from os.path import join +from pathlib import Path + +import numpy as np +import torch +from tqdm.auto import tqdm + +from rectools.semantic import OptimizerType, QuantizerType + +from .base_quantizer import Quantizer +from .emb_dataset import EmbDataset +from .rkmeans import RKmeans +from .rqvae import RQVAE + + +class SIDTokenizer: + """A class implementing semantic ID tokenizer, which can use either RK-Means or RQ-VAE as its backbone. + + The tokenizer owns the quantizer lifecycle: it constructs the appropriate + ``Quantizer`` from configuration parameters, delegates training-related + calls (``init_codebooks``, forward pass) to it, and maintains the + ``id2sid`` / ``sid2id`` vocabulary mappings. + + Parameters + ---------- + input_dim : int + Dimensionality of incoming item embeddings. + codebook_sizes : tp.List[int] + Number of codes per codebook level. + codebook_dim : int + Dimensionality of codebook embeddings (used by RQ-VAE; + ignored for RK-Means, where codebook dim equals ``input_dim``). + hidden_dims : tp.List[int] | None + Hidden layer sizes for the RQ-VAE encoder/decoder. + Pass ``None`` for RK-Means. + quantizer : QuantizerType + Which quantizer backend to use: ``"rkmeans"`` or ``"rqvae"``. + device : tp.Optional[str], optional + Device for quantizer parameters, by default None. + adapter_proj_dim : int, optional + Number of principal components for the RQ-VAE whitening adapter, + by default 512. Ignored for RK-Means. + """ + + def __init__( + self, + input_dim: int, + codebook_sizes: tp.List[int], + codebook_dim: int = 32, + hidden_dims: tp.Optional[tp.List[int]] = None, + quantizer: QuantizerType = "rkmeans", + device: tp.Optional[str] = None, + adapter_proj_dim: int = 512, + ) -> None: + self.input_dim = input_dim + self.codebook_sizes = list(codebook_sizes) + self.codebook_dim = codebook_dim + self.hidden_dims = hidden_dims + self.quantizer_name: QuantizerType = quantizer + self.device = device + self.adapter_proj_dim = adapter_proj_dim + + self._quantizer = self._build_quantizer() + self._quantizer.eval() + + # SIDs are tuples because they are hashable -> it is easier to check for conflicts + # since there can be conflicting SIDs, sid2id maps to a list + self.id2sid: tp.Dict[int, tp.Tuple[int, ...]] = {} + self.sid2id: tp.Dict[tp.Tuple[int, ...], tp.List[int]] = defaultdict(list) + + # ------------------------------------------------------------------ + # Quantizer construction + # ------------------------------------------------------------------ + + def _build_quantizer(self) -> Quantizer: + """Construct the quantizer from stored configuration and move it to ``self.device``.""" + quantizer: Quantizer + if self.quantizer_name == "rqvae": + quantizer = RQVAE( + input_dim=self.input_dim, + codebook_sizes=self.codebook_sizes, + hidden_dims=self.hidden_dims if self.hidden_dims is not None else [], + codebook_dim=self.codebook_dim, + adapter_proj_dim=self.adapter_proj_dim, + ) + else: + quantizer = RKmeans( + input_dim=self.input_dim, + codebook_sizes=self.codebook_sizes, + ) + return quantizer.to(self.device) + + # ------------------------------------------------------------------ + # Training helpers (delegated to the quantizer) + # ------------------------------------------------------------------ + + @property + def quantizer(self) -> Quantizer: + """Access the underlying quantizer module.""" + return self._quantizer + + def init_codebooks(self, dataset: EmbDataset) -> None: + """Initialize quantizer codebooks from an ``EmbDataset``. + + Parameters + ---------- + dataset : EmbDataset + Dataset whose embeddings are used for codebook initialization. + """ + embs = torch.Tensor(dataset[:]["embed"]).to(self.device) + self._quantizer.init_codebooks(embs) + + def __call__(self, batch: tp.Dict[str, tp.Any]) -> torch.Tensor: + """Run a forward pass through the quantizer on a batch and update ``id2sid``. + + Parameters + ---------- + batch : tp.Dict[str, tp.Any] + A dict with ``"item_id"`` and ``"embed"`` keys, as returned by + ``EmbDataset.__getitem__``. + + Returns + ------- + torch.Tensor + Scalar loss from the quantizer. + """ + item_ids = batch["item_id"] + if isinstance(item_ids, int): + item_ids = [item_ids] + + embs = torch.Tensor(batch["embed"]).to(self.device) + if embs.dim() == 1: + embs = embs.unsqueeze(0) + + out = self._quantizer(embs) + + for idx, item in enumerate(item_ids): + self.id2sid[item] = out.sem_ids[idx] + + return out.loss + + def train(self) -> "SIDTokenizer": + """Set the quantizer to training mode.""" + self._quantizer.train() + return self + + def eval(self) -> "SIDTokenizer": + """Set the quantizer to evaluation mode.""" + self._quantizer.eval() + return self + + def parameters(self) -> tp.Any: + """Return quantizer parameters (for optimizer construction).""" + return self._quantizer.parameters() + + # ------------------------------------------------------------------ + # Fit + # ------------------------------------------------------------------ + + @staticmethod + def _get_collision_rate(id2sid: tp.Dict[int, tp.Tuple[int, ...]]) -> float: + """Fraction of items whose SID collides with at least one other item.""" + sid_counts = Counter(id2sid.values()) + n_colliding = sum(c for c in sid_counts.values() if c > 1) + return n_colliding / len(id2sid) if id2sid else 1.0 + + @torch.no_grad() + def _get_id2sid(self, dataset: EmbDataset, batch_size: int = 2048) -> tp.Dict[int, tp.Tuple[int, ...]]: + """Run the quantizer over ``dataset`` in eval mode and return an id2sid mapping.""" + was_training = self._quantizer.training + self._quantizer.eval() + id2sid: tp.Dict[int, tp.Tuple[int, ...]] = {} + for start_idx in range(0, len(dataset), batch_size): + end_idx = min(len(dataset), start_idx + batch_size) + batch = dataset[start_idx:end_idx] + batch_embeds = torch.Tensor(batch["embed"]).to(self.device) + out = self._quantizer(batch_embeds) + for idx, item in enumerate(batch["item_id"]): + id2sid[item] = out.sem_ids[idx] + if was_training: + self._quantizer.train() + return id2sid + + def fit( # pylint: disable=too-many-locals + self, + item_ids: tp.List[int], + embeddings: np.ndarray, + init_max_items: int = 10_000, + max_epochs: int = 1_000, + patience: int = 100, + optimizer: OptimizerType = "adamw", + lr: float = 1e-3, + weight_decay: float = 1e-4, + batch_size: int = 2048, + save_dir: str = "checkpoints", + ) -> "SIDTokenizer": + """Train the quantizer on pre-computed item embeddings. + + After training, ``id2sid`` is populated for every item in ``item_ids``. + + Parameters + ---------- + item_ids : tp.List[int] + Item identifiers matching rows of ``embeddings``. + embeddings : np.ndarray + Embedding matrix of shape ``(len(item_ids), embed_dim)``. + init_max_items : int, optional + Maximum number of items used for codebook initialization, + by default 10_000. + max_epochs : int, optional + Maximum training epochs, by default 1_000. + patience : int, optional + Early stopping patience (epochs without collision rate + improvement), by default 100. + optimizer : OptimizerType, optional + Optimizer name, by default ``"adamw"``. + lr : float, optional + Learning rate, by default 1e-3. + weight_decay : float, optional + Weight decay, by default 1e-4. + batch_size : int, optional + Training batch size, by default 2048. + save_dir : str, optional + Directory for saving best checkpoint, by default ``"checkpoints"``. + + Returns + ------- + SIDTokenizer + ``self``, with trained quantizer and populated ``id2sid``. + """ + dataset = EmbDataset(item_ids, embeddings) + + # Initialize codebooks from a (possibly smaller) subset + init_dataset = EmbDataset(item_ids[:init_max_items], embeddings[:init_max_items]) + self.init_codebooks(init_dataset) + + opt: torch.optim.Optimizer + if optimizer.lower() == "adam": + opt = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay) + elif optimizer.lower() == "adagrad": + opt = torch.optim.Adagrad(self.parameters(), lr=lr, weight_decay=weight_decay) + else: + opt = torch.optim.AdamW(self.parameters(), lr=lr, weight_decay=weight_decay) + + best_collision_rate = float("inf") + steps_no_improve = 0 + best_path = Path( + join( + save_dir, + f"tokenizer_{self.quantizer_name}", + f"{datetime.now().strftime('%d%b%Y-%H.%M.%S.%f')}", + "best_perf.pt", + ) + ) + os.makedirs(best_path.parent, exist_ok=True) + + self.train() + with tqdm( + total=max_epochs * (len(dataset) // batch_size + (len(dataset) % batch_size > 0)), + ) as pbar: + for epoch_no in range(max_epochs): + pbar.set_description(f"Tokenizer train {epoch_no} / {max_epochs}") + for start_idx in range(0, len(dataset), batch_size): + end_idx = min(start_idx + batch_size, len(dataset)) + batch = dataset[start_idx:end_idx] + batch_embeds = torch.Tensor(batch["embed"]).to(self.device) + out = self._quantizer(batch_embeds) + loss = out.loss + opt.zero_grad() + loss.backward() + opt.step() + + metrics_dict = {"loss": loss.item()} + pbar.set_postfix({metric: f"{value:.4f}" for metric, value in metrics_dict.items()}) + pbar.update(1) + + # Compute collision rate + id2sid = self._get_id2sid(dataset, batch_size=batch_size) + collision_rate = self._get_collision_rate(id2sid) + + improved = collision_rate < best_collision_rate + if improved: + best_collision_rate = collision_rate + steps_no_improve = 0 + torch.save(self._quantizer.state_dict(), best_path) + else: + steps_no_improve += 1 + + self.train() + + if steps_no_improve >= patience: + print(f"Early stopping at epoch {epoch_no} (patience exceeded {patience})") + break + + print(f"Best model's collision rate: {best_collision_rate * 100:.2f}%") + self._quantizer.load_state_dict(torch.load(best_path, map_location=self.device)) + self.eval() + self.id2sid = self._get_id2sid(dataset, batch_size=batch_size) + return self + + # ------------------------------------------------------------------ + # Vocabulary extension + # ------------------------------------------------------------------ + + @torch.no_grad() + def extend(self, dataset: EmbDataset) -> None: + """Tokenize new items from an EmbDataset and add them to the vocabulary. + + Items already present in id2sid are skipped. + """ + for start_idx in range(0, len(dataset), 256): + end_idx = min(start_idx + 256, len(dataset)) + batch = dataset[start_idx:end_idx] + + item_ids = batch["item_id"] + if isinstance(item_ids, int): + item_ids = [item_ids] + + new_mask = [i for i, item in enumerate(item_ids) if item not in self.id2sid] + if not new_mask: + continue + + new_embs = torch.Tensor(batch["embed"]).to(self.device) + if new_embs.dim() == 1: + new_embs = new_embs.unsqueeze(0) + new_embs = new_embs[new_mask] + + out = self._quantizer(new_embs) + sids = out.sem_ids + for j, mask_idx in enumerate(new_mask): + item = item_ids[mask_idx] + sid = sids[j] + self.id2sid[item] = sid + if self.sid2id: + self.sid2id[sid].append(item) + + # ------------------------------------------------------------------ + # Tokenize / decode + # ------------------------------------------------------------------ + + def _get_sid(self, item: int) -> tp.Tuple[int, ...]: + if item in self.id2sid: + return self.id2sid[item] + raise ValueError( + f"Item with ID {item} is out of vocabulary. " "Use extend() with an EmbDataset containing this item first." + ) + + def _get_id( + self, + sid: tp.Tuple[int, ...], + default_value: tp.Optional[int] = None, + rng: tp.Optional[np.random.RandomState] = None, + ) -> tp.Optional[int]: + if len(self.sid2id) == 0: + for item_id, item_sid in self.id2sid.items(): + self.sid2id[item_sid].append(item_id) + if sid in self.sid2id: + if rng is None: + rng = np.random + return int(rng.choice(self.sid2id[sid])) + return default_value + + def tokenize( # pylint: disable=redefined-builtin + self, input: int | tp.Iterable[int] + ) -> tp.Tuple[int, ...] | tp.List[tp.Tuple[int, ...]]: + """Convert incoming item IDs to corresponding semantic IDs. + + Parameters + ---------- + input : int | tp.Iterable[int] + Item ID(s). + + Returns + ------- + tp.Tuple[int, ...] | tp.List[tp.Tuple[int, ...]] + Item SID(s). + """ + if isinstance(input, int): + return self._get_sid(input) + + output = [] + for item in input: + output.append(self._get_sid(item)) + return output + + def decode( # pylint: disable=redefined-builtin + self, + input: tp.Tuple[int, ...] | tp.Iterable[tp.Tuple[int, ...]], + default_value: tp.Optional[int] = None, + rng: tp.Optional[np.random.RandomState] = None, + ) -> tp.Optional[int] | tp.List[tp.Optional[int]]: + """Decode incoming SID(s) to corresponding item IDs. + + Parameters + ---------- + input : tp.Tuple[int, ...] | tp.Iterable[tp.Tuple[int, ...]] + Either a single SID or an iterable of SIDs. + default_value : tp.Optional[int], optional + What to return when there is no item corresponding to given SID, by default None + rng : np.random.RandomState, optional + Random generator used to resolve SID collisions. If omitted, the + global NumPy random generator is used. + + Returns + ------- + tp.Optional[int] | tp.List[tp.Optional[int]] + Item ID(s) corresponding to input SID(s). + """ + # check if the input is only a single SID + if isinstance(input, tuple) and all(map(lambda x: isinstance(x, int), input)): + return self._get_id(input, default_value=default_value, rng=rng) + # type: ignore[arg-type] + return list(map(lambda item: self._get_id(item, default_value=default_value, rng=rng), input)) + + # ------------------------------------------------------------------ + # Save / load + # ------------------------------------------------------------------ + + def save(self, path: tp.Union[str, PathLike[str]]) -> None: + """Save tokenizer to the specified path. + + The checkpoint is a dict containing quantizer configuration, + quantizer ``state_dict``, and the ``id2sid`` vocabulary mapping. + + Parameters + ---------- + path : str or PathLike + File path to write the checkpoint to. + """ + checkpoint = { + "quantizer_name": self.quantizer_name, + "input_dim": self.input_dim, + "codebook_sizes": self.codebook_sizes, + "codebook_dim": self.codebook_dim, + "hidden_dims": self.hidden_dims, + "adapter_proj_dim": self.adapter_proj_dim, + "quantizer_state_dict": self._quantizer.state_dict(), + "id2sid": self.id2sid, + } + torch.save(checkpoint, path) + + @classmethod + def load( + cls, + path: tp.Union[str, PathLike[str]], + device: tp.Optional[str] = None, + map_location: tp.Optional[str] = None, + ) -> "SIDTokenizer": + """Load serialized SIDTokenizer. + + Parameters + ---------- + path : str or PathLike + Location of the serialized tokenizer. + device : tp.Optional[str], optional + Device to load tokenizer's quantizer to, by default None. + map_location : tp.Optional[str], optional + ``map_location`` forwarded to ``torch.load``. If provided, + also used as the device. By default None. + + Returns + ------- + SIDTokenizer + A de-serialized tokenizer. + """ + effective_device = map_location or device + checkpoint = torch.load(path, map_location=effective_device, weights_only=False) + + tokenizer = cls( + input_dim=checkpoint["input_dim"], + codebook_sizes=checkpoint["codebook_sizes"], + codebook_dim=checkpoint["codebook_dim"], + hidden_dims=checkpoint["hidden_dims"], + quantizer=checkpoint["quantizer_name"], + device=effective_device, + adapter_proj_dim=checkpoint["adapter_proj_dim"], + ) + tokenizer._quantizer.load_state_dict(checkpoint["quantizer_state_dict"]) + tokenizer._quantizer.eval() + tokenizer.id2sid = checkpoint["id2sid"] + + return tokenizer + + def __len__(self) -> int: + return len(set(self.id2sid.values())) diff --git a/rectools/semantic/tokenizer/rkmeans.py b/rectools/semantic/tokenizer/rkmeans.py new file mode 100644 index 00000000..18cc3019 --- /dev/null +++ b/rectools/semantic/tokenizer/rkmeans.py @@ -0,0 +1,171 @@ +import typing as tp + +import numpy as np +import torch +from k_means_constrained import KMeansConstrained +from torch import nn +from torch.nn.functional import mse_loss, normalize +from tqdm.auto import tqdm + +from .base_quantizer import Quantizer, QuantizerOutput + + +class Codebook(nn.Module): + """An RK-Means codebook. Takes in item embeddings or their residuals from previous codebooks and outputs + codes corresponding to embeddings closest to input residuals. + + Parameters + ---------- + code_dim : int + Size of codebook embeddings. + n_codes : int + Number of codebook embeddings. + """ + + def __init__( + self, + code_dim: int, + n_codes: int, + ) -> None: + super().__init__() + self.code_dim = code_dim + self.n_codes = n_codes + self.code_embs = nn.Embedding(n_codes, code_dim) + self.kmeans_initted_ = False + + def init_from_centroids(self, centroids: torch.Tensor) -> None: + """Initialize codebook using predetermined centroids as embeddings. + + Parameters + ---------- + centroids : torch.Tensor + precomputed centroids of shape [n_codes, emb_dim] + """ + assert centroids.shape == (self.n_codes, self.code_dim) + with torch.no_grad(): + self.code_embs.weight.data.copy_(centroids) + self.kmeans_initted_ = True + + def forward(self, inputs: torch.Tensor) -> tp.Tuple: + """Run a forward pass. + + Parameters + ---------- + inputs : torch.Tensor + Input residuals. + + Returns + ------- + tp.Tuple + Corresponding codes, quantized reprezentations, residuals between quantized + representations and inputs. + """ + diffs = inputs.unsqueeze(1) - self.code_embs.weight.unsqueeze(0) + dists = torch.linalg.vector_norm(diffs, dim=-1) # pylint: disable=not-callable + codes = dists.argmin(dim=-1) + + quantized = self.code_embs.weight[codes, :] + residuals = inputs - quantized + + # Straight-through estimator: forward value is quantized, + # but gradients pass through to the encoder as if quantized == inputs. + quantized = inputs + (quantized - inputs).detach() + + return codes, quantized, residuals + + +class RKmeans(Quantizer): + """Apply Residual Mini-Batch K-Means (RK-Means) -- a multi-level vector quantizer + that applies quantization on residuals to generate a tuple of codewords (aka Semantic IDs). + + Parameters + ---------- + input_dim : int + Input embedding dimension. + codebook_sizes : tp.List[int] + Number of embeddings contained in RK-Means codebooks. + """ + + def __init__( + self, + input_dim: int, + codebook_sizes: tp.List[int], + ) -> None: + super().__init__( + input_dim=input_dim, + codebook_sizes=codebook_sizes, + ) + for n_codes in codebook_sizes: + self.codebooks.append(Codebook(self.input_dim, n_codes)) + + @torch.no_grad() + def init_codebooks(self, data: torch.Tensor) -> None: + """Initialize codebook centroids using constrained k-means. + + Runs the encoder on `data`, then for each codebook level fits + KMeansConstrained (with size_min per cluster) on the residuals + and copies the resulting centroids into the codebook embeddings. + + Parameters + ---------- + data : torch.Tensor + Input embeddings. + """ + residuals = normalize(data).cpu().numpy() + + for layer in tqdm(self.codebooks, desc="Initializing codebooks with constrained k-means"): + size_min = min(len(data) // (layer.n_codes * 2), 50) + size_max = size_min * 4 if layer.n_codes * size_min * 4 > len(data) else len(data) + km = KMeansConstrained( + n_clusters=layer.n_codes, + size_min=size_min, + size_max=size_max, + random_state=0, + max_iter=10, + n_init=10, + n_jobs=10, + verbose=False, + ) + km.fit(residuals) + centroids = torch.from_numpy(km.cluster_centers_).to(data.device) + layer.init_from_centroids(centroids) + + # Compute residuals for the next level + assignments = km.labels_ + residuals = residuals - km.cluster_centers_[assignments] + residuals = residuals / (np.linalg.norm(residuals, axis=1, keepdims=True) + 1e-12) + + def forward(self, inputs: torch.Tensor) -> QuantizerOutput: + """Run a forward pass + + Parameters + ---------- + inputs : torch.Tensor + Input item embeddings. + + Returns + ------- + tp.Tuple + Semantic IDs corresponding to input embeddings, reconstruction loss. + """ + res_prev = normalize(inputs) + embeds, sem_ids = [], [] + for layer in self.codebooks: + codes, quantized, res = layer(res_prev) + res_prev = res + res_prev = normalize(res_prev) + embeds.append(quantized) + sem_ids.append(codes) + + # embeds: [B] x num_codebooks -> [B, num_codebooks] -> [B] + embeds_tensor = torch.stack(embeds, dim=1).sum(dim=1) + sem_ids_tensor = torch.stack(sem_ids, dim=1).to(torch.int32) + + # Convert semids from tensors to tuples of ints + sem_ids = [ + tuple(sem_ids_tensor[idx, :].flatten().detach().cpu().tolist()) for idx in range(sem_ids_tensor.size(0)) + ] + + # Reconstruction loss + loss = mse_loss(inputs, embeds_tensor, reduction="mean") + return QuantizerOutput(sem_ids=sem_ids, loss=loss) diff --git a/rectools/semantic/tokenizer/rqvae.py b/rectools/semantic/tokenizer/rqvae.py new file mode 100644 index 00000000..7ae37d57 --- /dev/null +++ b/rectools/semantic/tokenizer/rqvae.py @@ -0,0 +1,304 @@ +import typing as tp + +import torch +from k_means_constrained import KMeansConstrained +from torch import nn +from torch.nn.functional import mse_loss +from tqdm.auto import tqdm + +from rectools.semantic.modules import MLP + +from .base_quantizer import Quantizer, QuantizerOutput +from .loss import CodebookLoss + + +def _get_pca_projection(embeddings: torch.Tensor, out_dim: int) -> torch.Tensor: + centered = embeddings - embeddings.mean(dim=0) + _, _, vh = torch.linalg.svd(centered, full_matrices=True) # pylint: disable=not-callable + return vh[:out_dim].T.contiguous() + + +class WhiteningAdapter(nn.Module): + r"""Applies an adapter consisting of PCA projection and an optional linear transformation to incoming item + embeddings. + + The adapter itself is a linear transformation that modifies input embedddings `x` in the following manner: + + .. math:: + \text{out}(x) = W \times (x - b) + + Where :math:`W` is weights matrix initialized as principal component matrix and :math:`b` is a vector of + biases initialized as mean of embedding. + + Parameters + ---------- + adapter_type : tp.Literal["ffn", "identity"], optional + Transformation applied after PCA projection, by default "identity" + emb_dim : int, optional + Dimension of incoming embeddings, by default 768 + proj_dim : int, optional + How many principal components to use, by default 512 + hidden_units : int, optional + Size of hidden layer in linear adapter, by default 256 + dropout : float, optional + Dropout probability for linear adapter, by default 0.1 + device : str, optional + Device the adapter is on, by default "cpu" + """ + + def __init__( + self, + adapter_type: tp.Literal["ffn", "identity"] = "identity", + emb_dim: int = 768, + proj_dim: int = 512, + hidden_units: int = 256, + dropout: float = 0.1, + device: str = "cpu", + ) -> None: + super().__init__() + + self.adapter_type = adapter_type + self.hidden_units = hidden_units + self.emb_dim = emb_dim + self.proj_dim = proj_dim + self.dropout = dropout + self.device = device + self.weight = nn.Parameter(torch.empty(self.emb_dim, self.proj_dim)) + self.bias = nn.Parameter(torch.empty(self.emb_dim)) + + self.head: nn.Module + if self.adapter_type == "ffn": + self.head = MLP(self.proj_dim, [self.hidden_units], self.proj_dim, dropout=self.dropout) + else: + self.head = nn.Identity() + + def freeze_adapter(self) -> None: + """Freeze PCA transform's weights and biases.""" + self.bias.requires_grad = False + self.weight.requires_grad = False + + @torch.no_grad() + def init_from_embeds(self, embeddings: torch.Tensor) -> None: + """Initialize PCA projection weights using item metadata embeddigns. + + Parameters + ---------- + embeddings : torch.Tensor + Item metadata embeddings + """ + proj = _get_pca_projection(embeddings, self.proj_dim) + + self.bias.data.copy_(embeddings.mean(dim=0)) + self.weight.data.copy_(proj) + self.freeze_adapter() + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """Adapter forward pass.""" + # PCA projection is precision-sensitive (1024-dim matmul accumulates + # bf16 rounding errors). Force fp32 even inside autocast. + with torch.amp.autocast(self.device, enabled=False): + projected = (inputs.float() - self.bias.float()) @ self.weight.float() + return self.head(projected) + + +class Codebook(nn.Module): + """An RQ-VAE codebook. Takes in RQ-VAE encoder outputs or their residuals from previous codebooks and outputs + codes corresponding to embeddings closest to input residuals. + + Parameters + ---------- + code_dim : int + Size of codebook embeddings. + n_codes : int + Number of codebook embeddings. + """ + + def __init__( + self, + code_dim: int, + n_codes: int, + ) -> None: + super().__init__() + self.code_dim = code_dim + self.n_codes = n_codes + self.code_embs = nn.Embedding(n_codes, code_dim) + self.cb_loss = CodebookLoss() + self.kmeans_initted_ = False + + def init_from_centroids(self, centroids: torch.Tensor) -> None: + """Initialize codebook using predetermined centroids as embeddings. + + Parameters + ---------- + centroids : torch.Tensor + precomputed centroids of shape [n_codes, emb_dim] + """ + assert centroids.shape == (self.n_codes, self.code_dim) + with torch.no_grad(): + self.code_embs.weight.data.copy_(centroids) + self.kmeans_initted_ = True + + def forward(self, inputs: torch.Tensor) -> tp.Tuple: + """Codebook forward pass. + + Parameters + ---------- + inputs : torch.Tensor + Input residuals expected to be of shape [B, emb_dim] + + Returns + ------- + tp.Tuple + Codes corresponding to embeddings closest to input residuals, embeddings themselves, resuduals, + and codebook loss. + """ + # inputs: [B, emb_dim] -> [B, 1, emb_dim] + # code_embs: [cb_size, emb_dim] -> [1, cb_size, emb_dim] + # This maps to diffs: [B, cb_size, emb_dim] + diffs = inputs.unsqueeze(1) - self.code_embs.weight.unsqueeze(0) + dists = torch.linalg.vector_norm(diffs, dim=-1) # pylint: disable=not-callable + codes = dists.argmin(dim=-1) # out: [B] + + quantized = self.code_embs(codes) + residuals = inputs - quantized + + # Straight-through estimator: forward value is quantized, + # but gradients pass through to the encoder as if quantized == inputs. + quantized_st = inputs + (quantized - inputs).detach() + + loss = self.cb_loss(inputs, quantized) + + return codes, quantized_st, residuals, loss + + +class RQVAE(Quantizer): + """Apply Residual-Quantized Variational AutoEncoder (RQ-VAE) -- a multi-level vector quantizer + that applies quantization on residuals to generate a tuple of codewords (aka Semantic IDs). + The Autoencoder is jointly trained by updating the quantization codebook and the DNN + encoder-decoder parameters. + + Parameters + ---------- + input_dim : int + Input embeddings size. + codebook_dim : int + Size of codebook embeddings. + hidden_dims : tp.List[int] + Encoder and decoder hidden layer sizes. + codebook_sizes : tp.List[int] + Number of embeddings each codebook contains. + adapter_type : tp.Literal["ffn", "identity"], optional + Transformation applied after PCA projection, by default "identity" + adapter_proj_dim : int, optional + How many principal components to use, by default 512 + """ + + def __init__( + self, + input_dim: int, + codebook_dim: int, + hidden_dims: tp.List[int], + codebook_sizes: tp.List[int], + adapter_type: tp.Literal["ffn", "identity"] = "identity", + adapter_proj_dim: int = 512, + ) -> None: + super().__init__( + input_dim=input_dim, + codebook_sizes=codebook_sizes, + ) + self.codebook_dim = codebook_dim + for n_codes in codebook_sizes: + self.codebooks.append(Codebook(codebook_dim, n_codes)) + + self.adapter_type = adapter_type + self.adapter_proj_dim = adapter_proj_dim + self.adapter = WhiteningAdapter( + adapter_type=adapter_type, + emb_dim=input_dim, + proj_dim=adapter_proj_dim, + ) + + self.encoder = MLP(adapter_proj_dim, hidden_dims, codebook_dim) + self.decoder = MLP(codebook_dim, hidden_dims[::-1], adapter_proj_dim) + + @torch.no_grad() + def init_codebooks(self, data: torch.Tensor) -> None: + """Initialize codebook centroids using constrained k-means. + + Runs the encoder on `data`, then for each codebook level fits + KMeansConstrained (with size_min per cluster) on the residuals + and copies the resulting centroids into the codebook embeddings. + + Parameters + ---------- + data : torch.Tensor + Item metadata embeddings. + """ + self.adapter.init_from_embeds(data) + embs = self.adapter(data) + encoded = self.encoder(embs).cpu().numpy() + residuals = encoded + + for layer in tqdm(self.codebooks, desc="Initializing codebooks with constrained k-means"): + size_min = min(len(data) // (layer.n_codes * 2), 50) + size_max = size_min * 4 if layer.n_codes * size_min * 4 > len(data) else len(data) + km = KMeansConstrained( + n_clusters=layer.n_codes, + size_min=size_min, + size_max=size_max, + random_state=0, + max_iter=10, + n_init=10, + n_jobs=10, + verbose=False, + ) + km.fit(residuals) + centroids = torch.from_numpy(km.cluster_centers_).to(data.device) + layer.init_from_centroids(centroids) + + # Compute residuals for the next level + assignments = km.labels_ + residuals = residuals - km.cluster_centers_[assignments] + + def forward(self, inputs: torch.Tensor) -> QuantizerOutput: + """RQ-VAE forward pass. + + Parameters + ---------- + inputs : torch.Tensor + Input embeddings + + Returns + ------- + tp.Tuple + Semantic IDs corresponding to input embeddings, RQ-VAE loss (codebook losses + + reconstruction loss). + """ + inputs_adapted = self.adapter(inputs) + inputs_enc = self.encoder(inputs_adapted) + + loss = torch.tensor(0.0, device=next(self.codebooks.parameters()).device) + + res_prev = inputs_enc + embeds, sem_ids = [], [] + for layer in self.codebooks: + codes, quantized_st, res, loss_upd = layer(res_prev) + res_prev = res + loss += loss_upd + embeds.append(quantized_st) + sem_ids.append(codes) + + embeds_tensor = torch.stack(embeds, dim=1).sum(dim=1) + sem_ids_tensor = torch.stack(sem_ids, dim=1).to(torch.int32) + + # Convert semids from tensors to tuples of ints + sem_ids = [ + tuple(sem_ids_tensor[idx, :].flatten().detach().cpu().tolist()) for idx in range(sem_ids_tensor.size(0)) + ] + + embeds_dec = self.decoder(embeds_tensor) + + # Reconstruction loss + loss += mse_loss(inputs_adapted, embeds_dec, reduction="mean") + + return QuantizerOutput(sem_ids=sem_ids, loss=loss) diff --git a/setup.cfg b/setup.cfg index cbd220b9..d1b23834 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,6 +49,9 @@ per-file-ignores = rectools/models/nn/dssm.py: D101,D102,N812 rectools/dataset/torch_datasets.py: D101,D102 rectools/models/implicit_als.py: N806 + rectools/semantic/modules/transformer_blocks.py: N806 + rectools/semantic/tiger/module.py: D101,D102,N806 + rectools/semantic/tiger/lightning.py: D101,D102 rectools/fast_transformers/net.py: N806 rectools/fast_transformers/unisrec/net.py: N806 diff --git a/tests/semantic/__init__.py b/tests/semantic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/semantic/data_handling/__init__.py b/tests/semantic/data_handling/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/semantic/data_handling/test_dataset.py b/tests/semantic/data_handling/test_dataset.py new file mode 100644 index 00000000..bafa7f75 --- /dev/null +++ b/tests/semantic/data_handling/test_dataset.py @@ -0,0 +1,170 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pandas as pd +import pytest +from pytorch_lightning import seed_everything + +from rectools.semantic.data_handling.dataset import PaddingCollateFn, TIGERDataset +from rectools.semantic.tokenizer.emb_dataset import EmbDataset +from rectools.semantic.tokenizer.model import SIDTokenizer + + +@pytest.fixture +def trained_tokenizer() -> SIDTokenizer: + seed_everything(42, workers=True) + rng = np.random.RandomState(42) + item_ids = list(range(1, 21)) + embeddings = rng.randn(20, 16).astype(np.float32) + dataset = EmbDataset(item_ids, embeddings) + tok = SIDTokenizer( + codebook_dim=16, + hidden_dims=None, + codebook_sizes=[4, 4], + input_dim=16, + device="cpu", + quantizer="rkmeans", + ) + tok.init_codebooks(dataset) + batch = dataset[:] + tok(batch) + return tok + + +@pytest.fixture +def interactions() -> pd.DataFrame: + return pd.DataFrame( + { + "user_id": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], + "item_id": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + "timestamp": list(range(1, 13)), + } + ) + + +class TestTIGERDataset: # pylint: disable=redefined-outer-name + def test_eval_mode_len(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None: + ds = TIGERDataset( + interactions=interactions, + tokenizer=trained_tokenizer, + codebook_sizes=[4, 4], + max_length=10, + codeword_offset=2, + bos_token_id=1, + only_last=True, + eval_mode=True, + ) + # 3 users, each gets one sample + assert len(ds) == 3 + + def test_eval_mode_getitem(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None: + ds = TIGERDataset( + interactions=interactions, + tokenizer=trained_tokenizer, + codebook_sizes=[4, 4], + max_length=10, + codeword_offset=2, + bos_token_id=1, + only_last=True, + eval_mode=True, + ) + sample = ds[0] + assert "input_ids" in sample + assert "labels" in sample + assert isinstance(sample["input_ids"], np.ndarray) + assert isinstance(sample["labels"], (int, np.integer)) + + def test_train_mode_only_last(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None: + ds = TIGERDataset( + interactions=interactions, + tokenizer=trained_tokenizer, + codebook_sizes=[4, 4], + max_length=10, + codeword_offset=2, + bos_token_id=1, + only_last=True, + eval_mode=False, + ) + assert len(ds) == 3 + sample = ds[0] + assert "input_ids" in sample + assert "dec_input" in sample + assert "labels" in sample + + def test_train_mode_all_subsequences(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None: + ds = TIGERDataset( + interactions=interactions, + tokenizer=trained_tokenizer, + codebook_sizes=[4, 4], + max_length=10, + codeword_offset=2, + bos_token_id=1, + only_last=False, + eval_mode=False, + ) + # Each user has 4 items -> 3 subsequences per user -> 9 total + assert len(ds) == 9 + + def test_max_users(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None: + ds = TIGERDataset( + interactions=interactions, + tokenizer=trained_tokenizer, + codebook_sizes=[4, 4], + max_length=10, + codeword_offset=2, + bos_token_id=1, + only_last=True, + eval_mode=True, + max_users=2, + ) + assert len(ds) == 2 + + def test_dec_input_starts_with_bos(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None: + ds = TIGERDataset( + interactions=interactions, + tokenizer=trained_tokenizer, + codebook_sizes=[4, 4], + max_length=10, + codeword_offset=2, + bos_token_id=1, + only_last=True, + eval_mode=False, + ) + sample = ds[0] + assert sample["dec_input"][0] == 1 # BOS token + + +class TestPaddingCollateFn: + def test_pads_sequences(self) -> None: + batch = [ + {"input_ids": np.array([1, 2, 3]), "labels": np.array([10, 11])}, + {"input_ids": np.array([4, 5]), "labels": np.array([12, 13])}, + ] + collate = PaddingCollateFn(padding_value=0, labels_padding_value=-100) + result = collate(batch) + assert result["input_ids"].shape == (2, 3) + assert result["labels"].shape == (2, 2) + # Second sequence should be padded + assert result["input_ids"][1, 2].item() == 0 + + def test_scalar_labels(self) -> None: + batch = [ + {"input_ids": np.array([1, 2, 3]), "labels": 10}, + {"input_ids": np.array([4, 5, 6]), "labels": 20}, + ] + collate = PaddingCollateFn(padding_value=0) + result = collate(batch) + assert result["labels"].shape == (2,) + assert result["labels"].tolist() == [10, 20] diff --git a/tests/semantic/data_handling/test_k_core.py b/tests/semantic/data_handling/test_k_core.py new file mode 100644 index 00000000..956b6ec8 --- /dev/null +++ b/tests/semantic/data_handling/test_k_core.py @@ -0,0 +1,87 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd +import pytest + +from rectools.semantic.data_handling.k_core import k_core + + +class TestKCore: + @pytest.fixture + def interactions(self) -> pd.DataFrame: + return pd.DataFrame( + { + "user_id": [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5], + "item_id": [10, 20, 30, 10, 20, 30, 10, 20, 30, 10, 20, 30, 10, 20, 30], + } + ) + + def test_no_filtering_needed(self, interactions: pd.DataFrame) -> None: + result = k_core(interactions, user_min_interactions=3, item_min_interactions=3) + assert len(result) == len(interactions) + + def test_filters_users_below_threshold(self) -> None: + df = pd.DataFrame( + { + "user_id": [1, 1, 1, 2], + "item_id": [10, 20, 30, 10], + } + ) + result = k_core(df, user_min_interactions=2, item_min_interactions=1) + assert set(result["user_id"].unique()) == {1} + + def test_filters_items_below_threshold(self) -> None: + df = pd.DataFrame( + { + "user_id": [1, 1, 2, 2, 3, 3], + "item_id": [10, 20, 10, 20, 10, 30], + } + ) + result = k_core(df, user_min_interactions=1, item_min_interactions=2) + assert 30 not in result["item_id"].values + + def test_iterative_filtering(self) -> None: + # User 3 has only item 30, but item 30 only has user 3 + # Removing item 30 removes user 3 interactions, then user 3 has 0 + df = pd.DataFrame( + { + "user_id": [1, 1, 2, 2, 3], + "item_id": [10, 20, 10, 20, 30], + } + ) + result = k_core(df, user_min_interactions=2, item_min_interactions=2) + assert 3 not in result["user_id"].values + assert 30 not in result["item_id"].values + assert len(result) == 4 + + def test_custom_column_names(self) -> None: + df = pd.DataFrame( + { + "uid": [1, 1, 2, 2], + "iid": [10, 20, 10, 20], + } + ) + result = k_core(df, user_min_interactions=2, item_min_interactions=2, user_col="uid", item_col="iid") + assert len(result) == 4 + + def test_empty_result(self) -> None: + df = pd.DataFrame( + { + "user_id": [1, 2, 3], + "item_id": [10, 20, 30], + } + ) + result = k_core(df, user_min_interactions=2, item_min_interactions=2) + assert len(result) == 0 diff --git a/tests/semantic/data_handling/test_loo_split.py b/tests/semantic/data_handling/test_loo_split.py new file mode 100644 index 00000000..4d42415b --- /dev/null +++ b/tests/semantic/data_handling/test_loo_split.py @@ -0,0 +1,89 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd + +from rectools.semantic.data_handling.loo_split import loo_split + + +class TestLooSplit: + def test_basic_split(self) -> None: + df = pd.DataFrame( + { + "user_id": [1, 1, 1, 2, 2, 2], + "item_id": [10, 20, 30, 40, 50, 60], + "timestamp": [1, 2, 3, 4, 5, 6], + } + ) + train, val, test = loo_split(df) + # Train: all but last 2 per user + assert len(train) == 2 # 1 per user + # Val: all but last 1 per user + assert len(val) == 4 # 2 per user + # Test: all interactions + assert len(test) == 6 + + def test_users_with_fewer_than_3_interactions_dropped(self) -> None: + df = pd.DataFrame( + { + "user_id": [1, 1, 1, 2, 2], + "item_id": [10, 20, 30, 40, 50], + "timestamp": [1, 2, 3, 4, 5], + } + ) + train, val, test = loo_split(df) + # User 2 has only 2 interactions, should be dropped + assert 2 not in train["user_id"].values + assert 2 not in val["user_id"].values + assert 2 not in test["user_id"].values + + def test_sorts_by_timestamp(self) -> None: + df = pd.DataFrame( + { + "user_id": [1, 1, 1], + "item_id": [30, 10, 20], + "timestamp": [3, 1, 2], + } + ) + train, val, test = loo_split(df) + # After sorting by timestamp: items are [10, 20, 30] + # Train = first item, val = first 2 items, test = all 3 + assert train["item_id"].tolist() == [10] + assert set(val["item_id"].tolist()) == {10, 20} + assert set(test["item_id"].tolist()) == {10, 20, 30} + + def test_no_timestamp_column(self) -> None: + df = pd.DataFrame( + { + "user_id": [1, 1, 1], + "item_id": [10, 20, 30], + } + ) + train, val, test = loo_split(df) + assert len(train) == 1 + assert len(val) == 2 + assert len(test) == 3 + + def test_all_users_dropped(self) -> None: + df = pd.DataFrame( + { + "user_id": [1, 1, 2, 2], + "item_id": [10, 20, 30, 40], + "timestamp": [1, 2, 3, 4], + } + ) + train, val, test = loo_split(df) + assert len(train) == 0 + assert len(val) == 0 + assert len(test) == 0 diff --git a/tests/semantic/modules/__init__.py b/tests/semantic/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/semantic/modules/test_mlp.py b/tests/semantic/modules/test_mlp.py new file mode 100644 index 00000000..b7e9a709 --- /dev/null +++ b/tests/semantic/modules/test_mlp.py @@ -0,0 +1,59 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from pytorch_lightning import seed_everything + +from rectools.semantic.modules.mlp import MLP + + +class TestMLP: + def setup_method(self) -> None: + seed_everything(42, workers=True) + + def test_output_shape(self) -> None: + mlp = MLP(input_dim=16, hidden_dims=[8], out_dim=4) + x = torch.randn(5, 16) + out = mlp(x) + assert out.shape == (5, 4) + + def test_no_hidden_dims(self) -> None: + mlp = MLP(input_dim=16, hidden_dims=[], out_dim=4) + x = torch.randn(5, 16) + out = mlp(x) + assert out.shape == (5, 4) + + def test_multiple_hidden_dims(self) -> None: + mlp = MLP(input_dim=32, hidden_dims=[16, 8], out_dim=4) + x = torch.randn(3, 32) + out = mlp(x) + assert out.shape == (3, 4) + + def test_with_dropout(self) -> None: + mlp = MLP(input_dim=16, hidden_dims=[8], out_dim=4, dropout=0.5) + x = torch.randn(5, 16) + out = mlp(x) + assert out.shape == (5, 4) + + def test_with_normalize(self) -> None: + mlp = MLP(input_dim=16, hidden_dims=[8], out_dim=4, normalize=True) + x = torch.randn(5, 16) + out = mlp(x) + assert out.shape == (5, 4) + + def test_invalid_input_dim_raises(self) -> None: + mlp = MLP(input_dim=16, hidden_dims=[8], out_dim=4) + with pytest.raises(AssertionError, match="Invalid input dim"): + mlp(torch.randn(5, 32)) diff --git a/tests/semantic/modules/test_transformer_blocks.py b/tests/semantic/modules/test_transformer_blocks.py new file mode 100644 index 00000000..878e7eef --- /dev/null +++ b/tests/semantic/modules/test_transformer_blocks.py @@ -0,0 +1,134 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from pytorch_lightning import seed_everything + +from rectools.semantic.modules.transformer_blocks import ( + T5Attention, + T5DecoderLayer, + T5EncoderLayer, + T5RMSNorm, + T5RelativePositionBias, + init_feed_forward, +) + + +class TestT5RMSNorm: + def test_output_shape(self) -> None: + norm = T5RMSNorm(d_model=16) + x = torch.randn(3, 5, 16) + out = norm(x) + assert out.shape == (3, 5, 16) + + def test_normalized_rms(self) -> None: + norm = T5RMSNorm(d_model=16) + x = torch.randn(3, 5, 16) + out = norm(x) + # Output should have roughly unit RMS (before learned scaling) + assert torch.isfinite(out).all() + + +class TestT5RelativePositionBias: + def test_output_shape(self) -> None: + bias = T5RelativePositionBias(num_heads=4, bidirectional=True) + out = bias(query_len=5, key_len=5, device=torch.device("cpu")) + assert out.shape == (1, 4, 5, 5) + + def test_causal_output_shape(self) -> None: + bias = T5RelativePositionBias(num_heads=2, bidirectional=False) + out = bias(query_len=3, key_len=3, device=torch.device("cpu")) + assert out.shape == (1, 2, 3, 3) + + def test_asymmetric_shape(self) -> None: + bias = T5RelativePositionBias(num_heads=2, bidirectional=True) + out = bias(query_len=3, key_len=7, device=torch.device("cpu")) + assert out.shape == (1, 2, 3, 7) + + +class TestT5Attention: + def setup_method(self) -> None: + seed_everything(42, workers=True) + + def test_self_attention_output_shape(self) -> None: + attn = T5Attention(d_model=16, num_heads=2, d_kv=8) + x = torch.randn(2, 5, 16) + out, _ = attn(x, x, x) + assert out.shape == (2, 5, 16) + + def test_cross_attention_output_shape(self) -> None: + attn = T5Attention(d_model=16, num_heads=2, d_kv=8, kv_input_dim=32) + q = torch.randn(2, 3, 16) + kv = torch.randn(2, 7, 32) + out, _ = attn(q, kv, kv) + assert out.shape == (2, 3, 16) + + def test_with_position_bias(self) -> None: + attn = T5Attention(d_model=16, num_heads=2, d_kv=8) + x = torch.randn(2, 5, 16) + pos_bias = torch.randn(1, 2, 5, 5) + out, _ = attn(x, x, x, position_bias=pos_bias) + assert out.shape == (2, 5, 16) + + +class TestT5EncoderLayer: + def setup_method(self) -> None: + seed_everything(42, workers=True) + + def test_output_shape(self) -> None: + layer = T5EncoderLayer(d_model=16, num_heads=2, d_kv=8, dim_feedforward=32) + x = torch.randn(2, 5, 16) + out = layer(x) + assert out.shape == (2, 5, 16) + + def test_with_padding_mask(self) -> None: + layer = T5EncoderLayer(d_model=16, num_heads=2, d_kv=8) + x = torch.randn(2, 5, 16) + mask = torch.tensor([[False, False, False, True, True], [False, False, True, True, True]]) + out = layer(x, src_key_padding_mask=mask) + assert out.shape == (2, 5, 16) + + +class TestT5DecoderLayer: + def setup_method(self) -> None: + seed_everything(42, workers=True) + + def test_output_shape(self) -> None: + layer = T5DecoderLayer(d_model=16, num_heads=2, d_kv=8, dim_feedforward=32) + tgt = torch.randn(2, 3, 16) + memory = torch.randn(2, 5, 16) + out = layer(tgt, memory) + assert out.shape == (2, 3, 16) + + +class TestInitFeedForward: + def test_relu(self) -> None: + ff = init_feed_forward(16, 4, 0.1, "relu") + out = ff(torch.randn(2, 5, 16)) + assert out.shape == (2, 5, 16) + + def test_gelu(self) -> None: + ff = init_feed_forward(16, 4, 0.1, "gelu") + out = ff(torch.randn(2, 5, 16)) + assert out.shape == (2, 5, 16) + + def test_swiglu(self) -> None: + ff = init_feed_forward(16, 4, 0.1, "swiglu") + out = ff(torch.randn(2, 5, 16)) + assert out.shape == (2, 5, 16) + + def test_unsupported_activation_raises(self) -> None: + with pytest.raises(ValueError, match="Unsupported"): + init_feed_forward(16, 4, 0.1, "tanh") diff --git a/tests/semantic/test_metrics.py b/tests/semantic/test_metrics.py new file mode 100644 index 00000000..8f38729c --- /dev/null +++ b/tests/semantic/test_metrics.py @@ -0,0 +1,75 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from rectools.semantic.metrics import coverage_k, gini_k + + +class TestCoverageK: + def test_all_items_covered(self) -> None: + all_topk = [[1, 2, 3], [4, 5]] + result = coverage_k(all_topk, k=3, num_items=5) + assert result == 1.0 + + def test_partial_coverage(self) -> None: + all_topk = [[1, 2], [1, 3]] + result = coverage_k(all_topk, k=2, num_items=5) + assert result == pytest.approx(3 / 5) + + def test_k_limits_items(self) -> None: + all_topk = [[1, 2, 3, 4, 5]] + result = coverage_k(all_topk, k=2, num_items=5) + assert result == pytest.approx(2 / 5) + + def test_empty_lists(self) -> None: + result = coverage_k([], k=5, num_items=10) + assert result == 0.0 + + def test_duplicate_items_across_users(self) -> None: + all_topk = [[1, 2], [2, 3], [3, 1]] + result = coverage_k(all_topk, k=2, num_items=3) + assert result == 1.0 + + +class TestGiniK: + def test_uniform_distribution(self) -> None: + all_topk = [[1], [2], [3], [4]] + result = gini_k(all_topk, k=1) + assert result == pytest.approx(0.0) + + def test_concentrated_distribution(self) -> None: + all_topk = [[1], [1], [1], [1]] + result = gini_k(all_topk, k=1) + assert result == pytest.approx(0.0) + + def test_empty_input(self) -> None: + result = gini_k([], k=5) + assert result == 0.0 + + def test_k_limits_items(self) -> None: + all_topk = [[1, 99], [2, 99]] + result_k1 = gini_k(all_topk, k=1) + result_k2 = gini_k(all_topk, k=2) + # With k=1 we only see items 1, 2 (uniform) -> gini=0 + assert result_k1 == pytest.approx(0.0) + # With k=2, item 99 appears twice but 1 and 2 appear once each + assert result_k2 > 0.0 + + def test_skewed_distribution(self) -> None: + # Item 1 recommended 3 times, item 2 once -> should be positive gini + all_topk = [[1], [1], [1], [2]] + result = gini_k(all_topk, k=1) + assert result > 0.0 + assert result < 1.0 diff --git a/tests/semantic/tiger/__init__.py b/tests/semantic/tiger/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/semantic/tiger/test_lightning.py b/tests/semantic/tiger/test_lightning.py new file mode 100644 index 00000000..535b6987 --- /dev/null +++ b/tests/semantic/tiger/test_lightning.py @@ -0,0 +1,106 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import torch +from pytorch_lightning import seed_everything + +from rectools.semantic.tiger.lightning import TIGERLightning +from rectools.semantic.tiger.module import TIGERNet +from rectools.semantic.tokenizer.emb_dataset import EmbDataset +from rectools.semantic.tokenizer.model import SIDTokenizer + + +@pytest.fixture +def trained_tokenizer() -> SIDTokenizer: + seed_everything(42, workers=True) + rng = np.random.RandomState(42) + item_ids = list(range(1, 21)) + embeddings = rng.randn(20, 16).astype(np.float32) + dataset = EmbDataset(item_ids, embeddings) + tok = SIDTokenizer( + codebook_dim=16, + hidden_dims=None, + codebook_sizes=[4, 4], + input_dim=16, + device="cpu", + quantizer="rkmeans", + ) + tok.init_codebooks(dataset) + batch = dataset[:] + tok(batch) + return tok + + +class TestTIGERLightning: # pylint: disable=redefined-outer-name + def setup_method(self) -> None: + seed_everything(42, workers=True) + + @pytest.fixture + def lightning_model(self, trained_tokenizer: SIDTokenizer) -> TIGERLightning: + net = TIGERNet( + codebook_sizes=[4, 4], + hidden_units=16, + num_blocks=1, + num_heads=2, + dropout_rate=0.0, + max_length=10, + ) + return TIGERLightning( + model=net, + tokenizer=trained_tokenizer, + beam_size=4, + top_k=3, + num_items=20, + ) + + def test_training_step(self, lightning_model: TIGERLightning) -> None: + batch = { + "input_ids": torch.randint(2, 10, (4, 6)), + "dec_input": torch.randint(0, 10, (4, 2)), + "labels": torch.randint(0, 4, (4, 2)), + } + loss = lightning_model.training_step(batch, batch_idx=0) + assert isinstance(loss, torch.Tensor) + assert loss.dim() == 0 + assert torch.isfinite(loss) + + def test_configure_optimizers_adamw(self, lightning_model: TIGERLightning) -> None: + result = lightning_model.configure_optimizers() + assert isinstance(result, torch.optim.AdamW) + + def test_configure_optimizers_with_schedule(self, trained_tokenizer: SIDTokenizer) -> None: + net = TIGERNet(codebook_sizes=[4, 4], hidden_units=16, num_blocks=1, num_heads=2) + model = TIGERLightning( + model=net, + tokenizer=trained_tokenizer, + lr_schedule="cosine", + warmup_steps=10, + max_iters=100, + ) + result = model.configure_optimizers() + assert isinstance(result, dict) + assert "optimizer" in result + assert "lr_scheduler" in result + + def test_configure_optimizers_unknown_raises(self, trained_tokenizer: SIDTokenizer) -> None: + net = TIGERNet(codebook_sizes=[4, 4], hidden_units=16, num_blocks=1, num_heads=2) + model = TIGERLightning( + model=net, + tokenizer=trained_tokenizer, + optimizer_name="unknown", # type: ignore[arg-type] + ) + with pytest.raises(ValueError, match="Unknown optimizer"): + model.configure_optimizers() diff --git a/tests/semantic/tiger/test_loss.py b/tests/semantic/tiger/test_loss.py new file mode 100644 index 00000000..0369332d --- /dev/null +++ b/tests/semantic/tiger/test_loss.py @@ -0,0 +1,55 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from rectools.semantic.tiger.loss import compute_tiger_loss + + +class TestComputeTigerLoss: + def test_correct_predictions_low_loss(self) -> None: + # Create logits that strongly predict the correct labels + batch_size = 4 + codebook_size = 8 + sid_len = 3 + labels = torch.randint(0, codebook_size, (batch_size, sid_len)) + logits = [] + for d in range(sid_len): + logit = torch.full((batch_size, codebook_size), -10.0) + for i in range(batch_size): + logit[i, labels[i, d]] = 10.0 + logits.append(logit) + loss = compute_tiger_loss(logits, labels) + assert loss.item() < 0.1 + + def test_random_predictions_higher_loss(self) -> None: + batch_size = 4 + codebook_size = 8 + sid_len = 3 + labels = torch.randint(0, codebook_size, (batch_size, sid_len)) + logits = [torch.randn(batch_size, codebook_size) for _ in range(sid_len)] + loss = compute_tiger_loss(logits, labels) + assert loss.item() > 0.1 + + def test_output_is_scalar(self) -> None: + logits = [torch.randn(2, 4), torch.randn(2, 4)] + labels = torch.randint(0, 4, (2, 2)) + loss = compute_tiger_loss(logits, labels) + assert loss.dim() == 0 + + def test_ignores_minus_100_labels(self) -> None: + logits = [torch.randn(2, 4)] + labels = torch.tensor([[-100], [1]]) + loss = compute_tiger_loss(logits, labels) + assert torch.isfinite(loss) diff --git a/tests/semantic/tiger/test_model.py b/tests/semantic/tiger/test_model.py new file mode 100644 index 00000000..683497d2 --- /dev/null +++ b/tests/semantic/tiger/test_model.py @@ -0,0 +1,245 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile + +import numpy as np +import pandas as pd +import pytest +import torch +from pytorch_lightning import seed_everything + +from rectools.semantic.tiger.model import TIGERModel +from rectools.semantic.tokenizer.emb_dataset import EmbDataset +from rectools.semantic.tokenizer.model import SIDTokenizer + + +@pytest.fixture +def trained_tokenizer() -> SIDTokenizer: + seed_everything(42, workers=True) + rng = np.random.RandomState(42) + item_ids = list(range(1, 51)) + embeddings = rng.randn(50, 16).astype(np.float32) + dataset = EmbDataset(item_ids, embeddings) + tok = SIDTokenizer( + codebook_dim=16, + hidden_dims=None, + codebook_sizes=[4, 4], + input_dim=16, + device="cpu", + quantizer="rkmeans", + ) + tok.init_codebooks(dataset) + batch = dataset[:] + tok(batch) + return tok + + +@pytest.fixture +def interactions() -> pd.DataFrame: + rows = [] + for user in range(1, 11): + for t, item in enumerate(range(1, 6)): + rows.append([user, item, t]) + return pd.DataFrame(rows, columns=["user_id", "item_id", "timestamp"]) + + +class TestTIGERModel: # pylint: disable=redefined-outer-name + def setup_method(self) -> None: + seed_everything(42, workers=True) + + def test_init(self, trained_tokenizer: SIDTokenizer) -> None: + model = TIGERModel( + tokenizer=trained_tokenizer, + hidden_units=16, + num_blocks=1, + num_heads=2, + max_length=10, + device="cpu", + ) + assert model.hidden_units == 16 + assert model.num_blocks == 1 + + def test_predict_output_format(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None: + model = TIGERModel( + tokenizer=trained_tokenizer, + hidden_units=16, + num_blocks=1, + num_heads=2, + max_length=10, + device="cpu", + beam_size=5, + top_k=3, + ) + result = model.predict(interactions, top_k=3) + assert isinstance(result, pd.DataFrame) + assert set(result.columns) == {"user_id", "item_id", "score", "rank"} + # Each user should have at most top_k recommendations + for _uid, group in result.groupby("user_id"): + assert len(group) <= 3 + + def test_save_and_load(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None: + model = TIGERModel( + tokenizer=trained_tokenizer, + hidden_units=16, + num_blocks=1, + num_heads=2, + max_length=10, + device="cpu", + ) + with tempfile.TemporaryDirectory() as tmpdir: + model.save(tmpdir) + assert os.path.exists(os.path.join(tmpdir, "tokenizer.pt")) + assert os.path.exists(os.path.join(tmpdir, "model.pt")) + assert os.path.exists(os.path.join(tmpdir, "config.json")) + + loaded = TIGERModel.load(tmpdir, device="cpu") + assert loaded.hidden_units == model.hidden_units + assert loaded.num_blocks == model.num_blocks + assert loaded.num_heads == model.num_heads + + def test_predict_batching(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None: + model = TIGERModel( + tokenizer=trained_tokenizer, + hidden_units=16, + num_blocks=1, + num_heads=2, + max_length=10, + device="cpu", + beam_size=5, + top_k=3, + eval_batch_size=2, + ) + result = model.predict(interactions, top_k=3) + assert isinstance(result, pd.DataFrame) + assert set(result.columns) == {"user_id", "item_id", "score", "rank"} + result_user_ids = set(result["user_id"].unique()) + input_user_ids = set(interactions["user_id"].unique()) + assert result_user_ids.issubset(input_user_ids) + for _uid, group in result.groupby("user_id"): + assert list(group["rank"]) == list(range(1, len(group) + 1)) + assert len(group) <= 3 + + def test_fit_runs(self, trained_tokenizer: SIDTokenizer, interactions: pd.DataFrame) -> None: + torch.use_deterministic_algorithms(True) + model = TIGERModel( + tokenizer=trained_tokenizer, + hidden_units=16, + num_blocks=1, + num_heads=2, + max_length=10, + max_epochs=1, + batch_size=16, + eval_batch_size=16, + num_workers=0, + beam_size=4, + top_k=3, + device="cpu", + ) + # Split interactions for train/val + train_df = interactions[interactions["user_id"] <= 7] + val_df = interactions[interactions["user_id"] > 7] + model.fit(train_df, val_df) + # After fit, model should be able to predict + result = model.predict(interactions, top_k=3) + assert len(result) > 0 + torch.use_deterministic_algorithms(False) + + def test_evaluate_uses_unique_item_count_for_coverage( + self, + trained_tokenizer: SIDTokenizer, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + model = TIGERModel( + tokenizer=trained_tokenizer, + hidden_units=16, + num_blocks=1, + num_heads=2, + max_length=10, + device="cpu", + ) + interactions = pd.DataFrame( + [ + [1, 10, 1], + [1, 20, 2], + [1, 30, 3], + [2, 10, 1], + [2, 30, 2], + [2, 50, 3], + ], + columns=["user_id", "item_id", "timestamp"], + ) + captured = {} + + class FakeTrainer: + def test(self, lightning_model, dataloaders): + captured["num_items"] = lightning_model.num_items + return [{"Coverage@10": 0.0}] + + monkeypatch.setattr("rectools.semantic.tiger.model.pl.Trainer", lambda: FakeTrainer()) + + model.evaluate(interactions) + + assert captured["num_items"] == 4 + + def test_predict_is_reproducible_with_random_seed(self, trained_tokenizer: SIDTokenizer) -> None: + collision_sid = (0, 0) + trained_tokenizer.id2sid = { + 10: collision_sid, + 20: collision_sid, + 30: (1, 1), + } + trained_tokenizer.sid2id.clear() + + interactions = pd.DataFrame( + [ + [1, 10, 1], + [1, 30, 2], + ], + columns=["user_id", "item_id", "timestamp"], + ) + + first_model = TIGERModel( + tokenizer=trained_tokenizer, + hidden_units=16, + num_blocks=1, + num_heads=2, + max_length=10, + device="cpu", + random_seed=42, + ) + second_model = TIGERModel( + tokenizer=trained_tokenizer, + hidden_units=16, + num_blocks=1, + num_heads=2, + max_length=10, + device="cpu", + random_seed=42, + ) + + generated_codes = torch.tensor([[[0, 0], [0, 0], [1, 1]]], dtype=torch.long) + generated_scores = torch.tensor([[0.9, 0.8, 0.7]], dtype=torch.float32) + + def fake_generate(*args, **kwargs): + return generated_codes, generated_scores + + first_model.model.generate = fake_generate # type: ignore[method-assign] + second_model.model.generate = fake_generate # type: ignore[method-assign] + + first_result = first_model.predict(interactions, top_k=3) + second_result = second_model.predict(interactions, top_k=3) + + assert first_result.equals(second_result) diff --git a/tests/semantic/tiger/test_module.py b/tests/semantic/tiger/test_module.py new file mode 100644 index 00000000..cc118c57 --- /dev/null +++ b/tests/semantic/tiger/test_module.py @@ -0,0 +1,126 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from pytorch_lightning import seed_everything + +from rectools.semantic.tiger.module import TIGERNet + + +class TestTIGERNetStandard: + """Tests for TIGERNet with standard nn.Transformer (d_kv=None).""" + + def setup_method(self) -> None: + seed_everything(42, workers=True) + + @pytest.fixture + def model(self) -> TIGERNet: + return TIGERNet( + codebook_sizes=[4, 4], + hidden_units=16, + num_blocks=1, + num_heads=2, + dropout_rate=0.0, + max_length=10, + ) + + def test_forward_shapes(self, model: TIGERNet) -> None: + batch_size = 3 + enc_len = 4 # 2 items * 2 codewords + dec_len = 2 # sid_len = 2 + enc_input = torch.randint(2, 10, (batch_size, enc_len)) + dec_input = torch.randint(0, 10, (batch_size, dec_len)) + logits = model(enc_input, dec_input) + assert len(logits) == 2 + assert logits[0].shape == (batch_size, 4) + assert logits[1].shape == (batch_size, 4) + + def test_generate_greedy_shape(self, model: TIGERNet) -> None: + model.eval() + enc_input = torch.randint(2, 10, (2, 4)) + codes = model.generate_greedy(enc_input) + assert codes.shape == (2, 2) # (batch_size, sid_len) + + def test_generate_beam_shape(self, model: TIGERNet) -> None: + model.eval() + enc_input = torch.randint(2, 10, (2, 4)) + codes, scores = model.generate(enc_input, beam_size=3) + assert codes.shape == (2, 3, 2) # (batch, beam, sid_len) + assert scores.shape == (2, 3) + + def test_generate_with_padding_mask(self, model: TIGERNet) -> None: + model.eval() + enc_input = torch.randint(2, 10, (2, 6)) + enc_input[1, 4:] = TIGERNet.PAD_TOKEN_ID + mask = enc_input == TIGERNet.PAD_TOKEN_ID + codes, _scores = model.generate(enc_input, enc_padding_mask=mask, beam_size=2) + assert codes.shape == (2, 2, 2) + + def test_sid_to_tokens_roundtrip(self, model: TIGERNet) -> None: + sid = torch.tensor([[0, 1], [2, 3]]) + tokens = model.sid_to_tokens(sid) + recovered = model.tokens_to_sid(tokens) + assert torch.equal(sid, recovered) + + def test_codes_in_valid_range(self, model: TIGERNet) -> None: + model.eval() + enc_input = torch.randint(2, 10, (3, 4)) + codes = model.generate_greedy(enc_input) + for d in range(model.sid_len): + assert (codes[:, d] >= 0).all() + assert (codes[:, d] < model.codebook_sizes[d]).all() + + +class TestTIGERNetT5: + """Tests for TIGERNet with T5-style layers (d_kv set).""" + + def setup_method(self) -> None: + seed_everything(42, workers=True) + + @pytest.fixture + def model(self) -> TIGERNet: + return TIGERNet( + codebook_sizes=[4, 4], + hidden_units=16, + num_blocks=1, + num_heads=2, + d_kv=8, + dropout_rate=0.0, + max_length=10, + ) + + def test_forward_shapes(self, model: TIGERNet) -> None: + enc_input = torch.randint(2, 10, (2, 4)) + dec_input = torch.randint(0, 10, (2, 2)) + logits = model(enc_input, dec_input) + assert len(logits) == 2 + assert logits[0].shape == (2, 4) + + def test_generate_greedy(self, model: TIGERNet) -> None: + model.eval() + enc_input = torch.randint(2, 10, (2, 4)) + codes = model.generate_greedy(enc_input) + assert codes.shape == (2, 2) + + def test_generate_beam(self, model: TIGERNet) -> None: + model.eval() + enc_input = torch.randint(2, 10, (2, 4)) + codes, scores = model.generate(enc_input, beam_size=3) + assert codes.shape == (2, 3, 2) + assert scores.shape == (2, 3) + + def test_no_positional_embeddings(self, model: TIGERNet) -> None: + assert model.enc_pos_emb is None + assert model.dec_pos_emb is None diff --git a/tests/semantic/tokenizer/__init__.py b/tests/semantic/tokenizer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/semantic/tokenizer/test_emb_dataset.py b/tests/semantic/tokenizer/test_emb_dataset.py new file mode 100644 index 00000000..5e6a196b --- /dev/null +++ b/tests/semantic/tokenizer/test_emb_dataset.py @@ -0,0 +1,50 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +from rectools.semantic.tokenizer.emb_dataset import EmbDataset + + +class TestEmbDataset: + @pytest.fixture + def dataset(self) -> EmbDataset: + item_ids = [10, 20, 30, 40] + embeddings = np.random.RandomState(42).randn(4, 8).astype(np.float32) + return EmbDataset(item_ids, embeddings) + + def test_len(self, dataset: EmbDataset) -> None: + assert len(dataset) == 4 + + def test_dim(self, dataset: EmbDataset) -> None: + assert dataset.dim == 8 + + def test_getitem_int(self, dataset: EmbDataset) -> None: + item = dataset[0] + assert item["item_id"] == 10 + assert item["embed"].shape == (8,) + + def test_getitem_slice(self, dataset: EmbDataset) -> None: + batch = dataset[0:2] + assert batch["item_id"] == [10, 20] + assert batch["embed"].shape == (2, 8) + + def test_getitem_full_slice(self, dataset: EmbDataset) -> None: + batch = dataset[:] + assert len(batch["item_id"]) == 4 + + def test_mismatched_lengths_raises(self) -> None: + with pytest.raises(ValueError, match="item_ids length"): + EmbDataset([1, 2, 3], np.zeros((4, 8))) diff --git a/tests/semantic/tokenizer/test_loss.py b/tests/semantic/tokenizer/test_loss.py new file mode 100644 index 00000000..eb506d9f --- /dev/null +++ b/tests/semantic/tokenizer/test_loss.py @@ -0,0 +1,47 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from rectools.semantic.tokenizer.loss import CodebookLoss + + +class TestCodebookLoss: + def test_zero_when_equal(self) -> None: + loss_fn = CodebookLoss() + x = torch.randn(4, 8) + result = loss_fn(x, x.clone()) + assert result.item() == 0.0 + + def test_positive_when_different(self) -> None: + loss_fn = CodebookLoss() + y_true = torch.randn(4, 8) + y_pred = torch.randn(4, 8) + result = loss_fn(y_true, y_pred) + assert result.item() > 0.0 + + def test_custom_coefficients(self) -> None: + loss_default = CodebookLoss(mu=1.0, beta=0.25) + loss_custom = CodebookLoss(mu=2.0, beta=0.5) + y_true = torch.randn(4, 8) + y_pred = torch.randn(4, 8) + r_default = loss_default(y_true, y_pred) + r_custom = loss_custom(y_true, y_pred) + # Custom has doubled coefficients, so loss should roughly double + assert r_custom.item() > r_default.item() + + def test_output_is_scalar(self) -> None: + loss_fn = CodebookLoss() + result = loss_fn(torch.randn(4, 8), torch.randn(4, 8)) + assert result.dim() == 0 diff --git a/tests/semantic/tokenizer/test_model.py b/tests/semantic/tokenizer/test_model.py new file mode 100644 index 00000000..adb27897 --- /dev/null +++ b/tests/semantic/tokenizer/test_model.py @@ -0,0 +1,229 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile + +import numpy as np +import pytest +import torch +from pytorch_lightning import seed_everything + +from rectools.semantic.tokenizer.emb_dataset import EmbDataset +from rectools.semantic.tokenizer.model import SIDTokenizer + + +class TestSIDTokenizerRKmeans: + def setup_method(self) -> None: + seed_everything(42, workers=True) + + @pytest.fixture + def dataset(self) -> EmbDataset: + rng = np.random.RandomState(42) + item_ids = list(range(100)) + embeddings = rng.randn(100, 16).astype(np.float32) + return EmbDataset(item_ids, embeddings) + + @pytest.fixture + def tokenizer(self, dataset: EmbDataset) -> SIDTokenizer: + tok = SIDTokenizer( + codebook_dim=16, + hidden_dims=None, + codebook_sizes=[8, 8], + input_dim=16, + device="cpu", + quantizer="rkmeans", + ) + tok.init_codebooks(dataset) + # Run a forward pass to populate id2sid + batch = dataset[:] + tok(batch) + return tok + + def test_init(self) -> None: + tok = SIDTokenizer( + codebook_dim=16, + hidden_dims=None, + codebook_sizes=[8, 8], + input_dim=16, + quantizer="rkmeans", + ) + assert tok.quantizer_name == "rkmeans" + assert tok.codebook_sizes == [8, 8] + + def test_forward_returns_loss(self, dataset: EmbDataset) -> None: + tok = SIDTokenizer( + codebook_dim=16, + hidden_dims=None, + codebook_sizes=[8, 8], + input_dim=16, + device="cpu", + quantizer="rkmeans", + ) + tok.init_codebooks(dataset) + batch = dataset[0:10] + loss = tok(batch) + assert isinstance(loss, torch.Tensor) + assert loss.dim() == 0 + assert loss.item() >= 0 + + def test_forward_populates_id2sid(self, dataset: EmbDataset) -> None: + tok = SIDTokenizer( + codebook_dim=16, + hidden_dims=None, + codebook_sizes=[8, 8], + input_dim=16, + device="cpu", + quantizer="rkmeans", + ) + tok.init_codebooks(dataset) + batch = dataset[0:10] + tok(batch) + assert len(tok.id2sid) == 10 + + def test_tokenize_single(self, tokenizer: SIDTokenizer) -> None: + sid = tokenizer.tokenize(0) + assert isinstance(sid, tuple) + assert len(sid) == 2 + + def test_tokenize_multiple(self, tokenizer: SIDTokenizer) -> None: + sids = tokenizer.tokenize([0, 1, 2]) + assert isinstance(sids, list) + assert len(sids) == 3 + for sid in sids: + assert isinstance(sid, tuple) + assert len(sid) == 2 + + def test_tokenize_unknown_item_raises(self, tokenizer: SIDTokenizer) -> None: + with pytest.raises(ValueError, match="out of vocabulary"): + tokenizer.tokenize(9999) + + def test_decode_single(self, tokenizer: SIDTokenizer) -> None: + sid = tokenizer.tokenize(5) + item_id = tokenizer.decode(sid) + # The decoded item should exist in the vocabulary + assert item_id is not None + + def test_decode_multiple(self, tokenizer: SIDTokenizer) -> None: + sids = tokenizer.tokenize([0, 1, 2]) + decoded = tokenizer.decode(sids) + assert isinstance(decoded, list) + assert len(decoded) == 3 + + def test_decode_unknown_sid(self, tokenizer: SIDTokenizer) -> None: + unknown_sid = (999, 999) + result = tokenizer.decode(unknown_sid) + assert result is None + + def test_decode_with_default_value(self, tokenizer: SIDTokenizer) -> None: + unknown_sid = (999, 999) + result = tokenizer.decode(unknown_sid, default_value=-1) + assert result == -1 + + def test_decode_collision_is_reproducible_with_rng(self, tokenizer: SIDTokenizer) -> None: + tokenizer.id2sid = { + 10: (1, 1), + 20: (1, 1), + 30: (2, 2), + } + tokenizer.sid2id.clear() + + first_rng = np.random.RandomState(42) + second_rng = np.random.RandomState(42) + + first = tokenizer.decode([(1, 1), (1, 1), (1, 1)], rng=first_rng) + second = tokenizer.decode([(1, 1), (1, 1), (1, 1)], rng=second_rng) + + assert first == second + + def test_len(self, tokenizer: SIDTokenizer) -> None: + assert len(tokenizer) > 0 + assert len(tokenizer) <= 100 + + def test_extend(self, tokenizer: SIDTokenizer) -> None: + rng = np.random.RandomState(99) + new_ids = list(range(100, 120)) + new_embeddings = rng.randn(20, 16).astype(np.float32) + new_dataset = EmbDataset(new_ids, new_embeddings) + tokenizer.extend(new_dataset) + # New items should be in id2sid + for item_id in new_ids: + assert item_id in tokenizer.id2sid + + def test_extend_skips_existing(self, tokenizer: SIDTokenizer) -> None: + old_sid = tokenizer.id2sid[0] + rng = np.random.RandomState(99) + # Include existing item 0 + ids = [0, 200] + embs = rng.randn(2, 16).astype(np.float32) + ds = EmbDataset(ids, embs) + tokenizer.extend(ds) + # Item 0 should still have the same SID + assert tokenizer.id2sid[0] == old_sid + assert 200 in tokenizer.id2sid + + def test_save_and_load(self, tokenizer: SIDTokenizer) -> None: + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, "tokenizer.pt") + tokenizer.save(path) + loaded = SIDTokenizer.load(path, map_location="cpu") + + assert loaded.codebook_sizes == tokenizer.codebook_sizes + assert loaded.input_dim == tokenizer.input_dim + assert loaded.codebook_dim == tokenizer.codebook_dim + assert set(loaded.id2sid.keys()) == set(tokenizer.id2sid.keys()) + + # Check that tokenization produces same results + sid_orig = tokenizer.tokenize(5) + sid_loaded = loaded.tokenize(5) + assert sid_orig == sid_loaded + + +class TestSIDTokenizerRQVAE: + def setup_method(self) -> None: + seed_everything(42, workers=True) + + @pytest.fixture + def dataset(self) -> EmbDataset: + rng = np.random.RandomState(42) + item_ids = list(range(100)) + embeddings = rng.randn(100, 16).astype(np.float32) + return EmbDataset(item_ids, embeddings) + + def test_init_rqvae(self) -> None: + tok = SIDTokenizer( + codebook_dim=8, + hidden_dims=[8], + codebook_sizes=[4, 4], + input_dim=16, + quantizer="rqvae", + ) + assert tok.quantizer_name == "rqvae" + + def test_forward_rqvae(self, dataset: EmbDataset) -> None: + tok = SIDTokenizer( + codebook_dim=8, + hidden_dims=[8], + codebook_sizes=[4, 4], + input_dim=16, + device="cpu", + quantizer="rqvae", + adapter_proj_dim=8, + ) + tok.init_codebooks(dataset) + batch = dataset[0:10] + loss = tok(batch) + assert isinstance(loss, torch.Tensor) + assert torch.isfinite(loss) + assert len(tok.id2sid) == 10 diff --git a/tests/semantic/tokenizer/test_rkmeans.py b/tests/semantic/tokenizer/test_rkmeans.py new file mode 100644 index 00000000..5fc11358 --- /dev/null +++ b/tests/semantic/tokenizer/test_rkmeans.py @@ -0,0 +1,79 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from pytorch_lightning import seed_everything + +from rectools.semantic.tokenizer.rkmeans import Codebook, RKmeans + + +class TestCodebook: + def setup_method(self) -> None: + seed_everything(42, workers=True) + + def test_output_shapes(self) -> None: + cb = Codebook(code_dim=8, n_codes=16) + inputs = torch.randn(4, 8) + codes, quantized, residuals = cb(inputs) + assert codes.shape == (4,) + assert quantized.shape == (4, 8) + assert residuals.shape == (4, 8) + + def test_codes_in_range(self) -> None: + cb = Codebook(code_dim=8, n_codes=16) + inputs = torch.randn(10, 8) + codes, _, _ = cb(inputs) + assert (codes >= 0).all() + assert (codes < 16).all() + + def test_init_from_centroids(self) -> None: + cb = Codebook(code_dim=4, n_codes=3) + centroids = torch.randn(3, 4) + cb.init_from_centroids(centroids) + assert cb.kmeans_initted_ + assert torch.allclose(cb.code_embs.weight.data, centroids) + + +class TestRKmeans: + def setup_method(self) -> None: + seed_everything(42, workers=True) + + def test_forward_shapes(self) -> None: + model = RKmeans(input_dim=16, codebook_sizes=[8, 8]) + inputs = torch.randn(10, 16) + sem_ids, loss = model(inputs) + assert len(sem_ids) == 10 + assert all(len(sid) == 2 for sid in sem_ids) + assert loss.dim() == 0 + + def test_loss_is_positive(self) -> None: + model = RKmeans(input_dim=16, codebook_sizes=[8, 8]) + inputs = torch.randn(10, 16) + _, loss = model(inputs) + assert loss.item() > 0.0 + + def test_init_codebooks(self) -> None: + model = RKmeans(input_dim=16, codebook_sizes=[4, 4]) + data = torch.randn(100, 16) + model.init_codebooks(data) + for layer in model.codebooks: + assert layer.kmeans_initted_ + + def test_sem_ids_are_tuples(self) -> None: + model = RKmeans(input_dim=8, codebook_sizes=[4, 4, 4]) + inputs = torch.randn(5, 8) + sem_ids, _ = model(inputs) + for sid in sem_ids: + assert isinstance(sid, tuple) + assert len(sid) == 3 diff --git a/tests/semantic/tokenizer/test_rqvae.py b/tests/semantic/tokenizer/test_rqvae.py new file mode 100644 index 00000000..a1b4223a --- /dev/null +++ b/tests/semantic/tokenizer/test_rqvae.py @@ -0,0 +1,106 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from pytorch_lightning import seed_everything + +from rectools.semantic.tokenizer.rqvae import RQVAE, WhiteningAdapter + + +class TestWhiteningAdapter: + def setup_method(self) -> None: + seed_everything(42, workers=True) + + def test_output_shape(self) -> None: + adapter = WhiteningAdapter(adapter_type="identity", emb_dim=16, proj_dim=8) + inputs = torch.randn(5, 16) + adapter.init_from_embeds(inputs) + output = adapter(inputs) + assert output.shape == (5, 8) + + def test_ffn_adapter(self) -> None: + adapter = WhiteningAdapter(adapter_type="ffn", emb_dim=16, proj_dim=8, hidden_units=4) + inputs = torch.randn(10, 16) + adapter.init_from_embeds(inputs) + output = adapter(inputs) + assert output.shape == (10, 8) + + def test_freeze_adapter(self) -> None: + adapter = WhiteningAdapter(emb_dim=16, proj_dim=8) + inputs = torch.randn(10, 16) + adapter.init_from_embeds(inputs) + assert not adapter.weight.requires_grad + assert not adapter.bias.requires_grad + + +class TestRQVAE: + def setup_method(self) -> None: + seed_everything(42, workers=True) + + def test_forward_shapes(self) -> None: + model = RQVAE( + input_dim=16, + codebook_dim=8, + hidden_dims=[8], + codebook_sizes=[4, 4], + adapter_proj_dim=8, + ) + data = torch.randn(20, 16) + model.init_codebooks(data) + sem_ids, loss = model(data) + assert len(sem_ids) == 20 + assert all(len(sid) == 2 for sid in sem_ids) + assert loss.dim() == 0 + + def test_loss_is_finite(self) -> None: + model = RQVAE( + input_dim=16, + codebook_dim=8, + hidden_dims=[8], + codebook_sizes=[4, 4], + adapter_proj_dim=8, + ) + data = torch.randn(20, 16) + model.init_codebooks(data) + _, loss = model(data) + assert torch.isfinite(loss) + + def test_init_codebooks_sets_flag(self) -> None: + model = RQVAE( + input_dim=16, + codebook_dim=8, + hidden_dims=[8], + codebook_sizes=[4, 4], + adapter_proj_dim=8, + ) + data = torch.randn(50, 16) + model.init_codebooks(data) + for layer in model.codebooks: + assert layer.kmeans_initted_ + + def test_sem_ids_are_tuples_of_ints(self) -> None: + model = RQVAE( + input_dim=16, + codebook_dim=8, + hidden_dims=[8], + codebook_sizes=[4, 4, 4], + adapter_proj_dim=8, + ) + data = torch.randn(20, 16) + model.init_codebooks(data) + sem_ids, _ = model(data) + for sid in sem_ids: + assert isinstance(sid, tuple) + assert all(isinstance(c, int) for c in sid) + assert len(sid) == 3 diff --git a/tests/semantic/tokenizer/test_train.py b/tests/semantic/tokenizer/test_train.py new file mode 100644 index 00000000..cfdee28f --- /dev/null +++ b/tests/semantic/tokenizer/test_train.py @@ -0,0 +1,72 @@ +# Copyright 2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile + +import numpy as np +from pytorch_lightning import seed_everything + +from rectools.semantic.tokenizer.model import SIDTokenizer + + +class TestSIDTokenizerFit: + def setup_method(self) -> None: + seed_everything(42, workers=True) + + def test_fit_rkmeans(self) -> None: + rng = np.random.RandomState(42) + item_ids = list(range(100)) + embeddings = rng.randn(100, 16).astype(np.float32) + + with tempfile.TemporaryDirectory() as tmpdir: + tok = SIDTokenizer( + input_dim=16, + codebook_sizes=[4, 4], + quantizer="rkmeans", + device="cpu", + ) + tok.fit( + item_ids=item_ids, + embeddings=embeddings, + max_epochs=3, + patience=2, + batch_size=50, + save_dir=tmpdir, + ) + assert len(tok.id2sid) > 0 + + def test_fit_rqvae(self) -> None: + rng = np.random.RandomState(42) + item_ids = list(range(100)) + embeddings = rng.randn(100, 16).astype(np.float32) + + with tempfile.TemporaryDirectory() as tmpdir: + tok = SIDTokenizer( + input_dim=16, + codebook_sizes=[4, 4], + codebook_dim=8, + hidden_dims=[8], + quantizer="rqvae", + adapter_proj_dim=8, + device="cpu", + ) + tok.fit( + item_ids=item_ids, + embeddings=embeddings, + max_epochs=3, + patience=2, + batch_size=50, + save_dir=tmpdir, + ) + assert len(tok.id2sid) > 0