Skip to content

Commit

Permalink
replaced kmedoids implementation from scikit-learn-extra to kmedoids
Browse files Browse the repository at this point in the history
  • Loading branch information
GuyTeichman committed Aug 21, 2024
1 parent 2c6c95b commit 9151af5
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 16 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ pyyaml>=6.0
UpSetPlot>=0.9.0
matplotlib-venn>=1.1.1
scipy>=1.14.0
scikit-learn-extra>=0.3.0
pairwisedist>=1.3.1
requests>=2.24.0
graphviz>=0.20.1
Expand All @@ -27,5 +26,6 @@ tenacity>=8.2.3
mslex>=1.1.0
nest-asyncio>=1.6.0
fastcluster>=1.2.6
kmedoids>=0.5.1
polars[async,numpy,pyarrow,matplotlib,pandas]>=1.5.0
pandas[performance,parquet]
12 changes: 5 additions & 7 deletions rnalysis/utils/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@
from sklearn.decomposition import PCA
from sklearn.metrics import pairwise_distances, silhouette_score, calinski_harabasz_score, davies_bouldin_score
from sklearn.utils import parallel_backend as sklearn_parallel_backend
from sklearn_extra.cluster import KMedoids
from tqdm.auto import tqdm

from kmedoids import KMedoids
from rnalysis.utils import generic, parsing, validation

try:
Expand Down Expand Up @@ -345,18 +344,17 @@ class KMedoidsIter:
'medoid_indices_': "the clustering solution's medoid indices",
'labels_': "the clustering solution's point labels"}

def __init__(self, n_clusters: int, metric: str = 'euclidean', init: str = 'k-medoids++', max_iter: int = 300,
def __init__(self, n_clusters: int, metric: str = 'euclidean', max_iter: int = 300,
n_init: int = 10, random_state: int = None):
assert isinstance(n_init, int), f"'n_init' must be an integer, is {type(n_init)} instead."
assert isinstance(metric, str), f"'metric' must be a string, is {type(metric)} instead."
assert n_init > 0, f"'n_init' must be a positive integer. Input {n_init} is invalid. "
self.n_clusters = n_clusters
self.metric = metric
self.n_init = n_init
self.init = init
self.max_iter = max_iter
self.random_state = random_state
self.clusterer = KMedoids(n_clusters=self.n_clusters, metric=self.metric, init=self.init,
self.clusterer = KMedoids(n_clusters=self.n_clusters, metric=self.metric,
max_iter=self.max_iter, random_state=random_state)
self.inertia_ = None
self.cluster_centers_ = None
Expand All @@ -375,10 +373,10 @@ def fit(self, x):
for i in range(self.n_init):
if self.random_state is not None:
clusterers.append(
KMedoids(n_clusters=self.n_clusters, metric=self.metric, init=self.init, max_iter=self.max_iter,
KMedoids(n_clusters=self.n_clusters, metric=self.metric, max_iter=self.max_iter,
random_state=self.random_state + i).fit(x))
else:
clusterers.append(KMedoids(n_clusters=self.n_clusters, metric=self.metric, init=self.init,
clusterers.append(KMedoids(n_clusters=self.n_clusters, metric=self.metric,
max_iter=self.max_iter).fit(x))
inertias[i] = clusterers[i].inertia_
best_clusterer = clusterers[int(np.argmax(inertias))]
Expand Down
15 changes: 7 additions & 8 deletions tests/test_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
import pytest

from rnalysis.utils.clustering import *
from rnalysis.utils.io import load_table


@pytest.fixture(scope='session')
def basic_counted_df():
return pl.read_csv('tests/test_files/counted.csv').drop(cs.first())


@pytest.fixture
def valid_clustering_solutions():
return [np.array([[1, 1, 0, 0, 0, 0, 0, 0], [0, 0, 1, 1, 1, 0, 0, 0], [0, 0, 0, 0, 0, 1, 1, 1]]),
Expand All @@ -31,18 +32,16 @@ def invalid_clustering_solutions():


def test_kmedoidsiter_api(basic_counted_df):
truth = KMedoids(3, max_iter=300, init='k-medoids++', random_state=42)
kmeds = KMedoidsIter(3, init='k-medoids++', max_iter=300, n_init=1, random_state=42)
truth = KMedoids(3, max_iter=300, random_state=42, metric='euclidean')
kmeds = KMedoidsIter(3, max_iter=300, n_init=1, random_state=42, metric='euclidean')
df = basic_counted_df
truth.fit(df)
kmeds.fit(df)
assert np.all(truth.cluster_centers_ == kmeds.cluster_centers_)
assert np.all(truth.inertia_ == kmeds.inertia_)

assert np.all(truth.predict(df) == kmeds.predict(df))
assert np.all(truth.fit_predict(df) == kmeds.fit_predict(df))

kmeds_rand = KMedoidsIter(3, init='k-medoids++', max_iter=300, n_init=3)
kmeds_rand = KMedoidsIter(3, max_iter=300, n_init=3)
kmeds_rand.fit(df)
kmeds_rand.predict(df)
kmeds_rand.fit_predict(df)
Expand All @@ -52,14 +51,14 @@ def test_kmedoidsiter_api(basic_counted_df):


def test_kmedoidsiter_iter(basic_counted_df):
kmeds = KMedoidsIter(3, init='k-medoids++', max_iter=300, n_init=5, random_state=0)
kmeds = KMedoidsIter(3, max_iter=300, n_init=5, random_state=0, metric='euclidean')
df = basic_counted_df
kmeds.fit(df)

inertias = []
clusterers = []
for i in range(5):
clusterers.append(KMedoids(3, max_iter=300, init='k-medoids++', random_state=0).fit(df))
clusterers.append(KMedoids(3, max_iter=300, random_state=0, metric='euclidean').fit(df))
inertias.append(clusterers[i].inertia_)
truth_inertia = max(inertias)
truth_kmeds = clusterers[np.argmax(inertias)]
Expand Down

0 comments on commit 9151af5

Please sign in to comment.