Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 31 additions & 16 deletions rehline/_mf_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import warnings

import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator
from sklearn.exceptions import ConvergenceWarning
from sklearn.utils.validation import _check_sample_weight
Expand Down Expand Up @@ -214,27 +213,29 @@ def __init__(
raise ValueError("; ".join(errors))

# parameter initialization
## -----------------------------basic perameters-----------------------------
## -----------------------------basic parameters-----------------------------
self.n_users = n_users
self.n_items = n_items
self.loss = loss
self.constraint_user = constraint_user if constraint_user is not None else []
self.constraint_item = constraint_item if constraint_item is not None else []
self.biased = biased
## -----------------------------hyper perameters-----------------------------
## -----------------------------hyper parameters-----------------------------
self.rank = rank
self.C = C
self.rho = rho
## --------------------------coefficient perameters--------------------------
## --------------------------coefficient parameters--------------------------
self.init_mean = init_mean
self.init_sd = init_sd
self.random_state = random_state
if self.random_state:
np.random.seed(random_state)
self.P = np.random.normal(loc=init_mean, scale=init_sd, size=(n_users, rank))
self.Q = np.random.normal(loc=init_mean, scale=init_sd, size=(n_items, rank))
self.bu = np.zeros(n_users) if self.biased else None
self.bi = np.zeros(n_items) if self.biased else None
if random_state is not None:
self.rng = np.random.default_rng(random_state)
else:
self.rng = np.random.default_rng()
self.P = None
self.Q = None
self.bu = None
self.bi = None
## ----------------------------fitting parameters----------------------------
self.max_iter_CD = max_iter_CD
self.tol_CD = tol_CD
Expand Down Expand Up @@ -271,11 +272,19 @@ def fit(self, X, y, sample_weight=None):
self.history = np.nan * np.zeros((self.max_iter_CD + 1, 2))
self.sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

X_df = pd.DataFrame(X, columns=["user", "item"])
uidx_map = X_df.groupby("user").indices
iidx_map = X_df.groupby("item").indices
self.Iu = [uidx_map.get(u, np.array([], dtype=int)) for u in range(self.n_users)]
self.Ui = [iidx_map.get(i, np.array([], dtype=int)) for i in range(self.n_items)]
sort_idx_users = np.argsort(X[:, 0], kind='stable')
sorted_users = X[sort_idx_users, 0]
counts = np.unique(sorted_users, return_counts=True)[1]
self.Iu = [np.array([], dtype=int)] * self.n_users
for u, idxs in zip(sorted_users[np.cumsum(counts) - counts], np.split(sort_idx_users, np.cumsum(counts)[:-1])):
self.Iu[u] = idxs

sort_idx_items = np.argsort(X[:, 1], kind='stable')
sorted_items = X[sort_idx_items, 1]
counts = np.unique(sorted_items, return_counts=True)[1]
self.Ui = [np.array([], dtype=int)] * self.n_items
for i, idxs in zip(sorted_items[np.cumsum(counts) - counts], np.split(sort_idx_items, np.cumsum(counts)[:-1])):
self.Ui[i] = idxs

C_user = self.C * self.n_users / (self.rho) / 2
C_item = self.C * self.n_items / (1 - self.rho) / 2
Expand All @@ -289,6 +298,12 @@ def fit(self, X, y, sample_weight=None):
)
)

# Initialization
self.P = self.rng.normal(loc=self.init_mean, scale=self.init_sd, size=(self.n_users, self.rank))
self.Q = self.rng.normal(loc=self.init_mean, scale=self.init_sd, size=(self.n_items, self.rank))
self.bu = np.zeros(self.n_users) if self.biased else None
self.bi = np.zeros(self.n_items) if self.biased else None

# CD algorithm
self.history[0] = self.obj(X, y)
for iter_idx in range(self.max_iter_CD):
Expand Down Expand Up @@ -435,7 +450,7 @@ def fit(self, X, y, sample_weight=None):
obj = f"{self.history[iter_idx + 1][1]:.6f}"
print(f"{iter_idx + 1:<12} {mean_loss:<20} {obj:<20}")

if obj_diff < self.tol_CD:
if abs(obj_diff) < self.tol_CD:
break

return self
Expand Down
Loading