Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/api/initializers/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,6 @@ Enumeration types that define available strategies for initialization.
:nosignatures:
:template: class.rst

ClusterMatchStrategy
MatchingMethod
ScoringMethod
EstimationStrategy
54 changes: 7 additions & 47 deletions rework_pysatl_mpest/initializers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,64 +9,24 @@

initializers provide good starting points for EM algorithm and other optimization
methods, helping to avoid poor local optima and improving convergence.

**Usage Example**

.. code-block:: python

>>> from rework_pysatl_mpest import Exponential
>>> import numpy as np
>>> from sklearn.cluster import KMeans
>>> from rework_pysatl_mpest.initializers import ClusterizeInitializer
>>> from rework_pysatl_mpest.initializers import ClusterMatchStrategy, EstimationStrategy

>>> # Create initializer with KMeans clustering
>>> initializer_cluster = ClusterizeInitializer(
... is_accurate=True,
... is_soft=False,
... clusterizer=KMeans(n_clusters=3)
... )

>>> # Create distribution models to initialize
>>> distributions = [Exponential(loc=0.0, rate=0.1),
>>>Exponential(loc=5.0, rate=0.05), Exponential(loc=10.0, rate=0.01)]

>>> # Generate sample data
>>> X = np.linspace(0.01, 25.0, 300)

>>> # Perform initialization
>>> mixture_model = initializer_cluster.perform(
... X=X,
... dists=distributions,
... cluster_match_strategy=ClusterMatchStrategy.AKAIKE,
... estimation_strategies=[EstimationStrategy.QFUNCTION] * len(distributions)
... )

>>> # The mixture model is now initialized with estimated parameters
>>> print(f"Number of components: {len(mixture_model.components)}")
>>> print(f"Weights: {mixture_model.weights}")
"""

__author__ = "Viktor Khanukaev"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

from ._estimation_strategies.q_function import q_function_strategy, q_function_strategy_exponential
from .cluster_match_strategy import (
match_clusters_for_models_akaike,
match_clusters_for_models_log_likelihood,
from .clusterize import (
ClusterizeInitializer,
EstimationStrategy,
MatchingMethod,
ScoringMethod,
)
from .clusterize_initializer import ClusterizeInitializer
from .initializer import Initializer
from .strategies import ClusterMatchStrategy, EstimationStrategy

__all__ = [
"ClusterMatchStrategy",
"ClusterizeInitializer",
"EstimationStrategy",
"Initializer",
"match_clusters_for_models_akaike",
"match_clusters_for_models_log_likelihood",
"q_function_strategy",
"q_function_strategy_exponential",
"MatchingMethod",
"ScoringMethod",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
"""
A subpackage containing estimation strategies for initialization.
"""

__author__ = "Viktor Khanukaev"
__copyright__ = "Copyright (c) 2025 PySATL project"
__license__ = "SPDX-License-Identifier: MIT"

from .q_function import q_function_strategy, q_function_strategy_exponential

__all__ = [
"q_function_strategy",
"q_function_strategy_exponential",
]
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,10 @@

import numpy as np

from rework_pysatl_mpest.distributions.continuous_dist import ContinuousDistribution
from rework_pysatl_mpest.distributions.exponential import Exponential
from rework_pysatl_mpest.optimizers import Optimizer
from ...distributions import ContinuousDistribution, Exponential
from ...optimizers import Optimizer

NUMERICAL_TOLERANCE = 0.33
NUMERICAL_TOLERANCE = 1e-6


@singledispatch
Expand Down Expand Up @@ -164,7 +163,7 @@ def q_function_strategy_exponential(
>>> print(f"Estimated loc: {params['loc']:.3f}, rate: {params['rate']:.3f}")
"""

new_params = {}
new_params: dict = {}
N_j = np.sum(H_j).item()

if np.any(H_j > NUMERICAL_TOLERANCE):
Expand Down
276 changes: 0 additions & 276 deletions rework_pysatl_mpest/initializers/cluster_match_strategy.py

This file was deleted.

Loading