


SS_KMEANS K-means clustering.
SPIKES = SS_KMEANS(SPIKES) takes and returns a spike-sorting object SPIKES.
It k-means clusters the (M x N) matrix SPIKES.WAVEFORMS and stores the
resulting group assignments in SPIKES.OVERCLUSTER.ASSIGNS, the cluster
centroids in SPIKES.OVERCLUSTER.CENTROIDS, and the mean squared error in
SPIKES.OVERCLUSTER.MSE. The W, B, and T (within-cluster, between-cluster,
and total) covariance matrices are stored in SPIKES.OVERCLUSTER.W, etc.
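
For example, a minimal call might look like the following (the random
waveform matrix is placeholder data and its size is arbitrary; only the
field names above come from this function's interface):

    spikes.waveforms = randn(5000, 32);          % M = 5000 waveforms x N = 32 samples (synthetic)
    spikes = ss_kmeans(spikes);                  % cluster with default options
    nclusts = max(spikes.overcluster.assigns)    % number of clusters found
    mse     = spikes.overcluster.mse             % mean squared error of the chosen solution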
The k-means algorithm is an EM-like procedure that finds K cluster centers in
N-dimensional space and assigns each of the M data vectors to one of these
K centers. The process is iterative: new cluster centers are computed as the
mean of their previously assigned vectors, and each vector is then reassigned
to its closest cluster center. This finds a local minimum of the mean squared
distance from each point to its cluster center; it is also a (local) MLE for
a model of K isotropic Gaussian clusters.
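
For illustration, a bare-bones version of that assign/update loop is sketched
below (X, C, and maxiter are placeholders; the implementation in this file is
organized differently and delegates the inner step to FAST_KMEANS_STEP):

    % X is M x N (one data vector per row); C is K x N (one center per row)
    for iter = 1:maxiter
        % squared distance from every vector to every center, expanded as
        % ||x||^2 - 2*x*c' + ||c||^2 so one matrix multiply does the work
        D = sum(X.^2,2)*ones(1,size(C,1)) - 2*X*C' + ones(size(X,1),1)*sum(C.^2,2)';
        [bestdist, assign] = min(D, [], 2);     % assignment step: nearest center
        for k = 1:size(C,1)                     % update step: mean of assigned vectors
            if (any(assign == k)), C(k,:) = mean(X(assign == k, :), 1); end
        end
    end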
This method is highly sensitive to outliers; the MLE is not robust to the
addition of a few waveforms unlike the others, and it will shift the 'true'
cluster centers (or add new ones) to account for these points. See
SS_OUTLIERS for one solution.
The algorithm used here speeds convergence by first solving for 2 means,
then using those means (slightly jittered) as starting points for a 4-means
solution. This continues for log2(K) splitting steps until all K means have
been found.
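
Schematically, each split doubles the current set of centers by duplicating
and perturbing them (X, K, and jitter are placeholders here; the code below
additionally prunes empty clusters and further splits oversized ones):

    C = mean(X, 1);                   % start from the grand mean (a single center)
    while (size(C,1) < K)
        C = [C; C] + jitter*randn(2*size(C,1), size(C,2));   % duplicate & jitter every center
        % ... then run the assign/update loop above until roughly converged,
        % so each split starts from a decent intermediate solution
    end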
SPIKES = SS_KMEANS(SPIKES, OPTIONS) allows specification of clustering
parameters. OPTIONS is a structure with some or all of the following fields
defined; any field left undefined (or every field, if no OPTIONS structure is
passed in) takes its default value. An example call is sketched after the
NOTE below.
OPTIONS.DIVISIONS (default: round(log2(M/500)), clipped to be between
4 and 8) sets the desired number of clusters to 2^DIVISIONS. The
actual number of clusters may be slightly more or less than this number.
OPTIONS.REPS (default: 1) specifies the number of runs of the full
k-means solution. The function will return the assignments that
resulted in the minimum MSE.
OPTIONS.REASSIGN_CONVERGE (default: 0) defines a convergence condition
by specifying the maximum number of vectors allowed to be reassigned
in an EM step. If no more than this number of vectors is reassigned,
this condition is met.
OPTIONS.MSE_CONVERGE (default: 0) defines a second convergence condition.
If the fractional change in mean squared error from one iteration
to the next is smaller than this value, this condition is met.
NOTE: Iteration stops when either of the convergence conditions is met.
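
For example (field names as documented above; the particular values are
arbitrary choices):

    opts.divisions    = 5;       % ask for 2^5 = 32 clusters
    opts.reps         = 3;       % keep the best of three full k-means runs
    opts.mse_converge = 1e-4;    % also stop when MSE improves by < 0.01% per iteration
    spikes = ss_kmeans(spikes, opts);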
References:
Duda RO et al (2001). _Pattern Classification_, Wiley-Interscience

function spikes = ss_kmeans(spikes, options)
% SS_KMEANS  K-means clustering.

% Last Modified By: sbm on Sun Aug 13 02:28:36 2006

% Undocumented options:
%    OPTIONS.PROGRESS (default: 1) determines whether the progress
%        bar is displayed during the clustering.
%    OPTIONS.SPLITMORE (default: 1) determines whether to do extra cluster
%        splitting to try for smaller clusters.

starttime = clock;

% Save the random number seeds before we do anything, in case we want to
% replicate this run exactly ...
randnstate = randn('state');   randstate = rand('state');

%%%%%%%%%% ARGUMENT CHECKING
if (~isfield(spikes, 'waveforms') || (size(spikes.waveforms, 1) < 1))
    error('SS:waveforms_undefined', 'The SS object does not contain any waveforms!');
end

%%%%%%%%%% CONSTANTS
target_clustersize = 500;
waves = spikes.waveforms;                  % reference w/o the struct is better for R13 acceleration
[M,N] = size(waves);
jitter = meandist_estim(waves) / 100 / N;  % heuristic scale for perturbing split centers

%%%%%%%%%% DEFAULTS
opts.divisions = round(log2(M / target_clustersize));  % power of 2 closest to target_clustersize spikes/cluster ...
opts.divisions = max(min(opts.divisions, 8), 4);       % ... restricted to [4..8], i.e., 16-256 clusters (heuristic)
opts.memorylimit = M;                        % all spikes at once
opts.memorylimit = 300;                      % (overrides the line above: process spikes in blocks rather than all at once)
opts.reps = 1;                               % just one repetition
opts.reassign_converge = 0;                  % the last few stubborn ones are probably borderline cases anyway ...
opts.reassign_rough = round(0.005*M);        % (no need at all to get full convergence for intermediate runs)
opts.mse_converge = 0;                       % ... and by default, we don't use the mse convergence criterion
opts.progress = 1;
opts.splitmore = 1;
if (nargin > 1)
    supplied = lower(fieldnames(options));   % which options did the user specify?
    for op = 1:length(supplied)              % copy those over the defaults
        opts.(supplied{op}) = options.(supplied{op});   % dynamic field names: the preferred syntax as of R13
    end
end

%%%%%%%%%% CLUSTERING
normsq = sum(waves.^2, 2);             % squared norm of each waveform (precomputed for the distance calculations)
assigns = ones(M, opts.reps);
mse = Inf * ones(1, opts.reps);
clear fast_kmeans;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% IMPORTANT NOTE: For BLAS efficiency reasons,  %
% the waveforms & centroids matrices are        %
% transposed from their ordinary orientation    %
% in the following section of code.             %
% E.g., waveforms is (D samples x N waveforms). %
waves = waves';                                 %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
for rep = 1:opts.reps                  % TOTAL # REPETITIONS

    centroid = mean(waves, 2);         % always start here
    for iter = 1:opts.divisions        % # K-MEANS SPLITS
        itercounter(iter) = 0;
        oldassigns = zeros(M, 1);   oldmse = Inf;
        assign_converge = 0;        mse_converge = 0;

        if (iter == opts.divisions), progress = opts.progress;  reassign_criterion = opts.reassign_converge;
        else,                        progress = 0;              reassign_criterion = opts.reassign_rough;
        end

        if ((iter == opts.divisions) && opts.splitmore)   % do some extra splitting on clusters that look too big
            toobig = find(clustersizes > 2*target_clustersize);
            splitbig = centroid(:,toobig) + randn(length(toobig), size(centroid,1))';
            centroid = [centroid splitbig];
        end

        centroid = [centroid centroid] + jitter * randn(2*size(centroid,2), size(centroid,1))';  % split & jitter
        clustersizes = ones(1, size(centroid,2));   % temporary placeholder to indicate no empty clusters at first

        if (opts.progress)
            progressBar(0, 1, sprintf('Calculating %d means.', size(centroid,2)));   % crude ...
        end

        while (~(assign_converge || mse_converge))   % convergence?
            %%%%% Clean out empty clusters
            centroid(:, clustersizes == 0) = [];

            %%%%% EM STEP
            [assignlist, bestdists, clustersizes, centroid] = ...
                                fast_kmeans_step(waves, centroid, normsq, opts.memorylimit);

            %%%%% Compute convergence info
            mse(rep) = mean(bestdists);
            mse_converge = ((1 - (mse(rep)/oldmse)) <= opts.mse_converge);   % fractional change in MSE
            oldmse = mse(rep);

            changed_assigns = sum(assignlist ~= oldassigns);
            assign_converge = (changed_assigns <= reassign_criterion);       % # waveforms reassigned
            if (progress)
                progressBar(((M - changed_assigns)/M).^10, 5);   % rough ...
            end
            oldassigns = assignlist;
            itercounter(iter) = itercounter(iter) + 1;
        end
    end

    % Finally, reassign spikes in singleton clusters to the next best cluster ...
    junkclust = find(clustersizes <= 1);   junkspikes = ismember(assignlist, junkclust);
    centroid(:,junkclust) = Inf;           % take these out of the running & recompute distances
    if (any(junkspikes))
        dist_to_rest = pairdist(waves(:,junkspikes)', centroid', 'nosqrt');
        [bestdists(junkspikes), assignlist(junkspikes)] = min(dist_to_rest, [], 2);
    end
    mse(rep) = mean(bestdists);

    assigns(:,rep) = assignlist;
end

spikes.overcluster.iteration_count = itercounter;
spikes.overcluster.randn_state = randnstate;
spikes.overcluster.rand_state = randstate;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Waves & centroids now return to their usual   %
% orientation.  E.g., waveforms is again        %
% (N waveforms x D samples).                    %
waves = waves';                                 %
centroid = centroid';                           %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Finish up by selecting the lowest MSE over repetitions.
[bestmse, choice] = min(mse);
spikes.overcluster.assigns = sortAssignments(assigns(:,choice));
spikes.overcluster.mse = bestmse;
spikes.overcluster.sqerr = bestdists;

% We also save the winning cluster centers as a convenience.
numclusts = max(spikes.overcluster.assigns);
spikes.overcluster.centroids = zeros(numclusts, N);
for clust = 1:numclusts
    members = find(spikes.overcluster.assigns == clust);
    spikes.overcluster.centroids(clust,:) = mean(waves(members,:), 1);
end

% And the W, B, T matrices -- easy, since T & B are fast to compute and T = W + B.
spikes.overcluster.T = cov(waves);             % normalize everything by M-1, not M
spikes.overcluster.B = cov(spikes.overcluster.centroids(spikes.overcluster.assigns, :));
spikes.overcluster.W = spikes.overcluster.T - spikes.overcluster.B;

% Finally, assign colors to the spike clusters here for consistency ...
cmap = jetm(numclusts);
spikes.overcluster.colors = cmap(randperm(numclusts),:);

if (opts.progress), progressBar(1.0, 1, ''); end
spikes.tictoc.kmeans = etime(clock, starttime);
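
FAST_KMEANS_STEP and the other helpers called above (meandist_estim, pairdist,
progressBar, sortAssignments, jetm) live in separate files and are not shown
here. Purely as a sketch, one assignment-and-update pass consistent with the
call site could look like the following (kmeans_step_sketch is a hypothetical
stand-in, not the toolbox function; memorylimit is assumed to be the number of
waveforms processed per block):

    function [assignlist, bestdists, clustersizes, centroid] = kmeans_step_sketch(waves, centroid, normsq, memorylimit)
    % One assign/update pass.  waves: D x M, centroid: D x K, normsq: M x 1 squared waveform norms.
    M = size(waves, 2);   K = size(centroid, 2);
    assignlist = zeros(M, 1);   bestdists = zeros(M, 1);
    cnormsq = sum(centroid.^2, 1);                 % 1 x K squared centroid norms
    for first = 1:memorylimit:M                    % process waveforms in blocks to limit memory
        block = first:min(first + memorylimit - 1, M);
        % squared distances via ||x||^2 - 2*x'*c + ||c||^2 (one BLAS matrix multiply)
        dists = normsq(block)*ones(1,K) - 2*(waves(:,block)'*centroid) + ones(length(block),1)*cnormsq;
        [bestdists(block), assignlist(block)] = min(dists, [], 2);
    end
    clustersizes = zeros(1, K);
    for k = 1:K                                    % recompute each centroid from its new members
        members = (assignlist == k);
        clustersizes(k) = sum(members);
        if (clustersizes(k) > 0), centroid(:,k) = mean(waves(:,members), 2); end
    end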