


SS_KMEANS K-means clustering.
SPIKES = SS_KMEANS(SPIKES) takes and returns a spike-sorting object SPIKES.
It k-means clusters the (M x N) matrix SPIKES.WAVEFORMS and stores the
resulting group assignments in SPIKES.OVERCLUSTER.ASSIGNS, the cluster
centroids in SPIKES.OVERCLUSTER.CENTROIDS, and the mean squared error in
SPIKES.OVERCLUSTER.MSE. The within-cluster, between-cluster, and total
covariance matrices (W, B, and T, with T = W + B) are stored in
SPIKES.OVERCLUSTER.W, SPIKES.OVERCLUSTER.B, and SPIKES.OVERCLUSTER.T.
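For example, a minimal call looks like the following sketch. It assumes a SPIKES
structure whose WAVEFORMS field already holds the (M x N) waveform matrix; the
data here are placeholders only.

    spikes.waveforms = randn(5000, 60);          % placeholder data: M = 5000 waveforms, N = 60 samples
    spikes = ss_kmeans(spikes);                  % cluster with default options
    assigns   = spikes.overcluster.assigns;      % (M x 1) cluster label for each waveform
    centroids = spikes.overcluster.centroids;    % (K x N) cluster centers
    mse       = spikes.overcluster.mse;          % mean squared distance to assigned centers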
The k-means algorithm is an EM-like procedure that finds K cluster centers in
N-dimensional space and assigns each of the M data vectors to one of these
K points. The process is iterative: new cluster centers are computed as the
means of their previously assigned vectors, and each vector is then reassigned
to its closest cluster center. This finds a local minimum of
the mean squared distance from each point to its cluster center; this is
also a (local) maximum-likelihood estimate (MLE) for a model of K isotropic Gaussian clusters.
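In outline, one such iteration looks like the following sketch, written for
clarity rather than speed; the source below delegates this step to a separate
helper, FAST_KMEANS_STEP.

    % X is the (M x N) data matrix, C is the (K x N) matrix of current centers.
    D = repmat(sum(X.^2,2), 1, size(C,1)) + repmat(sum(C.^2,2)', size(X,1), 1) - 2*X*C';  % squared distances (M x K)
    [bestdists, assignlist] = min(D, [], 2);          % assign each vector to its nearest center
    for k = 1:size(C,1)
        members = (assignlist == k);
        if any(members)
            C(k,:) = mean(X(members,:), 1);           % new center = mean of its assigned vectors
        end
    end
    mse = mean(bestdists);                            % objective being (locally) minimized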
This method is highly sensitive to outliers; the MLE is not robust to the
addition of a few waveforms that are unlike the others, which will shift the
'true' cluster centers (or add new ones) to account for these points. See
SS_OUTLIERS for one solution.
The algorithm used here speeds convergence by first solving for 2 means,
then using these means (slightly jittered) as starting points for a 4-means
solution. This continues for log2(K) splitting steps until K means have been found.
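Schematically, the splitting loop looks like this. It is a sketch only:
RUN_KMEANS here is a stand-in for repeating the EM step sketched above until it
roughly converges, and in the actual code the jitter scale is estimated from
the data.

    C = mean(X, 1);                                           % start from the grand mean (one center)
    for split = 1:divisions                                   % after all splits, K = 2^divisions
        C = [C; C] + jitter * randn(2*size(C,1), size(C,2));  % duplicate every center and jitter the copies
        C = run_kmeans(X, C);                                 % stand-in: EM passes as sketched above
    end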
SPIKES = SS_KMEANS(SPIKES, OPTIONS) allows specification of clustering
parameters. OPTIONS is a structure with some/all of the following fields
defined. Any field left undefined (or every field, if no OPTIONS structure is
passed in) takes its default value; see the example call below.
OPTIONS.DIVISIONS (default: round(log2(M/500)), clipped to be between
4 and 8) sets the desired number of clusters to 2^DIVISIONS. The
actual number of clusters may be slightly more or less than this.
OPTIONS.REPS (default: 1) specifies the number of runs of the full
k-means solution. The function will return the assignments that
resulted in the minimum MSE.
OPTIONS.REASSIGN_CONVERGE (default: 0) defines a convergence condition
by specifying the max number of vectors allowed to be reassigned
in an EM step. If no more than this number of vectors is reassigned in
a step, this condition is met.
OPTIONS.MSE_CONVERGE (default: 0) defines a second convergence condition.
If the fractional change in mean squared error from one iteration
to the next is smaller than this value, this condition is met.
NOTE: Iteration stops when either of the convergence conditions is met.
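For example, to ask for roughly 2^5 = 32 clusters, keep the best of three full
runs, and relax the reassignment criterion to ten waveforms per step:

    opts.divisions = 5;                  % target about 2^5 = 32 clusters
    opts.reps = 3;                       % run the full solution three times, keep the lowest MSE
    opts.reassign_converge = 10;         % converged once <= 10 vectors change assignment in a step
    spikes = ss_kmeans(spikes, opts);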
References:
Duda RO et al (2001). _Pattern Classification_, Wiley-Interscience

function spikes = ss_kmeans(spikes, options)

% Last Modified By: sbm on Sun Aug 13 02:28:36 2006

% Undocumented options:
%    OPTIONS.PROGRESS (default: 1) determines whether the progress
%            bar is displayed during the clustering.
%    OPTIONS.SPLITMORE (default: 1) determines whether to do extra cluster
%            splitting to try for smaller clusters.

starttime = clock;

% Save the random number seeds before we do anything, in case we want to
% replicate this run exactly ...
randnstate = randn('state');   randstate = rand('state');

%%%%%%%%%% ARGUMENT CHECKING
if (~isfield(spikes, 'waveforms') || (size(spikes.waveforms, 1) < 1))
    error('SS:waveforms_undefined', 'The SS object does not contain any waveforms!');
end

%%%%%%%%%% CONSTANTS
target_clustersize = 500;
waves = spikes.waveforms;                    % ref w/o struct is better for R13 acceleration
[M,N] = size(waves);
jitter = meandist_estim(waves) / 100 / N;    % heuristic

%%%%%%%%%% DEFAULTS
opts.divisions = round(log2(M / target_clustersize));  % power of 2 that gives closest to target_clustersize ...
opts.divisions = max(min(opts.divisions, 8), 4);       % ... restricted to [4..8], i.e., 16-256 clusters; heuristic.
opts.memorylimit = M;                        % all spikes at once
opts.memorylimit = 300;                      % (overrides the default above)
opts.reps = 1;                               % just one repetition
opts.reassign_converge = 0;                  % the last few stubborn ones are probably borderline cases anyway ...
opts.reassign_rough = round(0.005*M);        % (no need at all to get full convergence for intermediate runs)
opts.mse_converge = 0;                       % ... and by default, we don't use the mse convergence criterion
opts.progress = 1;
opts.splitmore = 1;
if (nargin > 1)
    supplied = lower(fieldnames(options));             % which options did the user specify?
    for op = 1:length(supplied)                        % copy those over the defaults
        opts.(supplied{op}) = options.(supplied{op});  % this is the preferred syntax as of R13
    end
end

%%%%%%%%%% CLUSTERING
normsq = sum(waves.^2, 2);
assigns = ones(M, opts.reps);
mse = Inf * ones(1, opts.reps);
clear fast_kmeans;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% IMPORTANT NOTE: For BLAS efficiency reasons,    %
%   the waveforms & centroids matrices are        %
%   transposed from their ordinary orientation    %
%   in the following section of code.             %
%   E.g., waveforms is (D samples x N waveforms). %
waves = waves';                                   %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
for rep = 1:opts.reps                        % TOTAL # REPETITIONS

    centroid = mean(waves, 2);               % always start here
    itercounter = zeros(1,opts.divisions);
    for iter = 1:opts.divisions              % # K-MEANS SPLITS
        itercounter(iter) = 0;
        oldassigns = zeros(M, 1);   oldmse = Inf;
        assign_converge = 0;   mse_converge = 0;

        if (iter == opts.divisions)
            progress = opts.progress;   reassign_criterion = opts.reassign_converge;
        else
            progress = 0;   reassign_criterion = opts.reassign_rough;
        end

        if ((iter == opts.divisions) && opts.splitmore)   % do some extra splitting on clusters that look too big
            toobig = find(clustersizes > 2*target_clustersize);
            splitbig = centroid(:,toobig) + randn(length(toobig),size(centroid,1))';
            centroid = [centroid splitbig];
        end

        centroid = [centroid centroid] + jitter * randn(2*size(centroid, 2),size(centroid,1))';  % split & jitter
        clustersizes = ones(1,size(centroid,2));   % temporary placeholder to indicate no empty clusters at first

        if (opts.progress)
            progressBar(0, 1, sprintf('Calculating %d means.', size(centroid,2)));   % crude ...
        end

        while (~(assign_converge || mse_converge))        % convergence?
            %%%%% Clean out empty clusters
            centroid(:,(clustersizes == 0)) = [];

            %%%%% EM STEP
            [assignlist, bestdists, clustersizes, centroid] = ...
                fast_kmeans_step(waves, centroid, normsq, opts.memorylimit);

            %%%%% Compute convergence info
            mse(rep) = mean(bestdists);
            mse_converge = ((1 - (mse(rep)/oldmse)) <= opts.mse_converge);   % fractional change
            oldmse = mse(rep);

            changed_assigns = sum(assignlist ~= oldassigns);
            assign_converge = (changed_assigns <= reassign_criterion);       % num waveforms reassigned
            if (progress)
                progressBar(((M - changed_assigns)/M).^10, 5);   % rough ...
            end
            oldassigns = assignlist;
            itercounter(iter) = itercounter(iter) + 1;
        end
    end

    % Finally, reassign spikes in singleton clusters to next best ...
    junkclust = find((clustersizes <= 1));   junkspikes = ismember(assignlist, junkclust);
    centroid(:,junkclust) = Inf;             % take these out of the running & recompute distances
    if (any(junkspikes))
        dist_to_rest = pairdist(waves(:,junkspikes)', centroid', 'nosqrt');
        [bestdists(junkspikes), assignlist(junkspikes)] = min(dist_to_rest, [], 2);
    end
    mse(rep) = mean(bestdists);

    assigns(:,rep) = assignlist;
end

spikes.overcluster.iteration_count = itercounter;
spikes.overcluster.randn_state = randnstate;
spikes.overcluster.rand_state = randstate;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Waves & centroids now return to their usual     %
%   orientation.  E.g., waveforms is again        %
%   (N waveforms x D samples).                    %
waves = waves';                                   %
centroid = centroid';                             %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Finish up by selecting the lowest mse over repetitions.
[bestmse, choice] = min(mse);
spikes.overcluster.assigns = sortAssignments(assigns(:,choice));
spikes.overcluster.mse = bestmse;
spikes.overcluster.sqerr = bestdists;

% We also save the winning cluster centers as a convenience.
numclusts = max(spikes.overcluster.assigns);
spikes.overcluster.centroids = zeros(numclusts, N);
for clust = 1:numclusts
    members = find(spikes.overcluster.assigns == clust);
    spikes.overcluster.centroids(clust,:) = mean(waves(members,:), 1);
end

% And the W, B, T matrices -- easy, since T & B are fast to compute and T = W + B.
spikes.overcluster.T = cov(waves);   % normalize everything by M-1, not M
spikes.overcluster.B = cov(spikes.overcluster.centroids(spikes.overcluster.assigns, :));
spikes.overcluster.W = spikes.overcluster.T - spikes.overcluster.B;

% Finally, assign colors to the spike clusters here for consistency ...
cmap = jetm(numclusts);
spikes.overcluster.colors = cmap(randperm(numclusts),:);

if (opts.progress), progressBar(1.0, 1, ''); end
spikes.tictoc.kmeans = etime(clock, starttime);