
ss_kmeans

PURPOSE

SS_KMEANS K-means clustering.

SYNOPSIS

function spikes = ss_kmeans(spikes, options)

DESCRIPTION

 SS_KMEANS  K-means clustering.
     SPIKES = SS_KMEANS(SPIKES) takes and returns a spike-sorting object SPIKES.
     It k-means clusters the (M x N) matrix SPIKES.WAVEFORMS and stores the
     resulting group assignments in SPIKES.OVERCLUSTER.ASSIGNS, the cluster
     centroids in SPIKES.OVERCLUSTER.CENTROIDS, and the mean squared error in
     SPIKES.OVERCLUSTER.MSE.  The W, B, and T (within-cluster, between-cluster,
     and total) covariance matrices are stored in SPIKES.OVERCLUSTER.W, etc.
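
     A minimal sketch of that decomposition (T = W + B), mirroring the
     computation near the end of the source below; here 'waves' is the (M x N)
     waveform matrix, 'assigns' the (M x 1) cluster labels, and 'centroids' the
     (K x N) matrix of cluster centers:

         T = cov(waves);                       % total covariance
         B = cov(centroids(assigns, :));       % between-cluster covariance
         W = T - B;                            % within-cluster part, since T = W + B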

     The K-means algorithm is an EM-like procedure that finds K cluster centers in
     N-dimensional space and assigns each of the M data vectors to one of these
     K points.  The process is iterative: new cluster centers are calculated as
     the mean of their previously assigned vectors and vectors are then each 
     reassigned to their closest cluster center.  This finds a local minimum for
     the mean squared distance from each point to its cluster center; this is
     also a (local) MLE for the model of K isotropic Gaussian clusters.
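
     A bare-bones sketch of one such iteration (illustrative only, not the
     exact code used below; X is an M x N data matrix, C a K x N matrix of
     cluster centers):

         M = size(X, 1);  K = size(C, 1);
         d = zeros(M, K);
         for k = 1:K
             d(:,k) = sum((X - repmat(C(k,:), M, 1)).^2, 2);  % squared distance to center k
         end
         [mindist, assigns] = min(d, [], 2);                  % assign each vector to its nearest center
         for k = 1:K
             C(k,:) = mean(X(assigns == k, :), 1);            % recompute each center as the mean of its members
         end
         % (the code below also removes empty clusters and checks convergence)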
 
     This method is highly outlier sensitive; MLE is not robust to the addition
     of a few waveforms that are not like the others and will shift the 'true'
     cluster centers (or add new ones) to account for these points.  See
     SS_OUTLIERS for one solution.
 
     The algorithm used here speeds convergence by first solving for 2 means,
     then using these means (slightly jittered) as starting points for a 4-means
     solution.  This continues for log2(K) steps until all K means have been found.
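
     A sketch of that split step (each current center is duplicated and nudged
     by a small random perturbation before the next round of k-means; C is a
     K x N matrix of centers and 'jitter' a small scalar, which the source
     derives from a mean inter-waveform distance estimate):

         C = [C; C] + jitter * randn(2*size(C,1), size(C,2));  % 2K perturbed starting centers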

     SPIKES = SS_KMEANS(SPIKES, OPTIONS) allows specification of clustering
     parameters.  OPTIONS is a structure with some or all of the following fields
     defined; any field left undefined (or all fields, if no OPTIONS structure
     is passed in) takes its default value.  (A usage sketch follows the NOTE
     below.)

         OPTIONS.DIVISIONS (default: round(log2(M/500)), clipped to be between
               4 and 8) sets the desired number of clusters to 2^DIVISIONS.  The
                actual number of clusters may be slightly more or less than this number.
         OPTIONS.REPS (default: 1) specifies the number of runs of the full
               k-means solution.  The function will return the assignments that
               resulted in the minimum MSE.
         OPTIONS.REASSIGN_CONVERGE (default: 0) defines a convergence condition
                by specifying the maximum number of vectors allowed to be reassigned
                in an EM step.  If no more than this number of vectors is reassigned,
                this condition is met.
         OPTIONS.MSE_CONVERGE (default: 0) defines a second convergence condition.
               If the fractional change in mean squared error from one iteration
               to the next is smaller than this value, this condition is met.

     NOTE: Iteration stops when either of the convergence conditions is met.
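
     A minimal usage sketch (it assumes SPIKES.WAVEFORMS has already been
     filled in by an earlier stage of the spike-sorting pipeline; the option
     values are illustrative only):

         opts.divisions = 5;               % aim for roughly 2^5 = 32 clusters
         opts.reps = 3;                    % run 3 times, keep the lowest-MSE assignment
         opts.mse_converge = 1e-4;         % also stop when MSE improves by < 0.01% per step
         spikes = ss_kmeans(spikes, opts);
         nclust = max(spikes.overcluster.assigns);
         hist(spikes.overcluster.assigns, 1:nclust);   % quick look at the cluster sizes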
  
 References:
     Duda RO et al (2001).  _Pattern Classification_, Wiley-Interscience

CROSS-REFERENCE INFORMATION

This function calls:
This function is called by:

SOURCE CODE

0001 function spikes = ss_kmeans(spikes, options)
0002 % SS_KMEANS  K-means clustering.
0003 %     SPIKES = SS_KMEANS(SPIKES) takes and returns a spike-sorting object SPIKES.
0004 %     It k-means clusters the (M x N) matrix SPIKES.WAVEFORMS and stores the
0005 %     resulting group assignments in SPIKES.OVERCLUSTER.ASSIGNS, the cluster
0006 %     centroids in SPIKES.OVERCLUSTER.CENTROIDS, and the mean squared error in
0007 %     SPIKES.OVERCLUSTER.MSE.  The W, B, and T (within-cluster, between-cluster,
0008 %     and total) covariance matrices are stored in SPIKES.OVERCLUSTER.W, etc.
0009 %
0010 %     The K-means algorithm is an EM-like procedure that finds K cluster centers in
0011 %     N-dimensional space and assigns each of the M data vectors to one of these
0012 %     K points.  The process is iterative: new cluster centers are calculated as
0013 %     the mean of their previously assigned vectors and vectors are then each
0014 %     reassigned to their closest cluster center.  This finds a local minimum for
0015 %     the mean squared distance from each point to its cluster center; this is
0016 %     also a (local) MLE for the model of K isotropic Gaussian clusters.
0017 %
0018 %     This method is highly outlier sensitive; MLE is not robust to the addition
0019 %     of a few waveforms that are not like the others and will shift the 'true'
0020 %     cluster centers (or add new ones) to account for these points.  See
0021 %     SS_OUTLIERS for one solution.
0022 %
0023 %     The algorithm used here speeds convergence by first solving for 2 means,
0024 %     then using these means (slightly jittered) as starting points for a 4-means
0025 %     solution.  This continues for log2(K) steps until all K means have been found.
0026 %
0027 %     SPIKES = SS_KMEANS(SPIKES, OPTIONS) allows specification of clustering
0028 %     parameters.  OPTIONS is a structure with some or all of the following fields
0029 %     defined; any field left undefined (or all fields, if no OPTIONS structure
0030 %     is passed in) takes its default value.
0031 %
0032 %         OPTIONS.DIVISIONS (default: round(log2(M/500)), clipped to be between
0033 %               4 and 8) sets the desired number of clusters to 2^DIVISIONS.  The
0034 %               actual number of clusters may be slightly more or less than this number.
0035 %         OPTIONS.REPS (default: 1) specifies the number of runs of the full
0036 %               k-means solution.  The function will return the assignments that
0037 %               resulted in the minimum MSE.
0038 %         OPTIONS.REASSIGN_CONVERGE (default: 0) defines a convergence condition
0039 %               by specifying the maximum number of vectors allowed to be reassigned
0040 %               in an EM step.  If no more than this number of vectors is reassigned,
0041 %               this condition is met.
0042 %         OPTIONS.MSE_CONVERGE (default: 0) defines a second convergence condition.
0043 %               If the fractional change in mean squared error from one iteration
0044 %               to the next is smaller than this value, this condition is met.
0045 %
0046 %     NOTE: Iteration stops when either of the convergence conditions is met.
0047 %
0048 % References:
0049 %     Duda RO et al (2001).  _Pattern Classification_, Wiley-Interscience
0050 
0051 %   Last Modified By: sbm on Sun Aug 13 02:28:36 2006
0052 
0053 % Undocumented options:
0054 %       OPTIONS.PROGRESS (default: 1) determines whether the progress
0055 %                                     bar is displayed during the clustering.
0056 %       OPTIONS.SPLITMORE (default: 1) determines whether to do extra cluster
0057 %                                       splitting to try for smaller clusters
0058 
0059 starttime = clock;
0060 
0061 % save the random number seeds before we do anything, in case we want to
0062 % replicate this run exactly ...
0063 randnstate = randn('state');  randstate = rand('state');
0064 
0065 %%%%%%%%%% ARGUMENT CHECKING
0066 if (~isfield(spikes, 'waveforms') | (size(spikes.waveforms, 1) < 1))
0067     error('SS:waveforms_undefined', 'The SS object does not contain any waveforms!');
0068 end
0069 
0070 %%%%%%%%%% CONSTANTS
0071 target_clustersize = 500;
0072 waves = spikes.waveforms;         % ref w/o struct is better for R13 acceleration
0073 [M,N] = size(waves);
0074 jitter = meandist_estim(waves) / 100 / N;        % heuristic
0075 
0076 %%%%%%%%%% DEFAULTS
0077 opts.divisions = round(log2(M / target_clustersize));  % power of 2 that gives cluster sizes closest to target_clustersize ...
0078 opts.divisions = max(min(opts.divisions, 8), 4);       %     ... clipped to [4..8], i.e. 16-256 clusters; heuristic.
0079 opts.memorylimit = M;                   % all spikes at once
0080 opts.memorylimit = 300;                 % ... but default to 300 at a time (overrides the line above)
0081 opts.reps = 1;                          % just one repetition
0082 opts.reassign_converge = 0;             % the last few stubborn ones are probably borderline cases anyway ...
0083 opts.reassign_rough = round(0.005*M);   % (no need at all to get full convergence for intermediate runs)
0084 opts.mse_converge = 0;                  % ... and by default, we don't use the mse convergence criterion
0085 opts.progress = 1;
0086 opts.splitmore = 1;
0087 if (nargin > 1)
0088     supplied = lower(fieldnames(options));   % which options did the user specify?
0089     for op = 1:length(supplied)              % copy those over the defaults
0090         opts.(supplied{op}) = options.(supplied{op});  % dynamic field names; preferred syntax as of R13
0091     end
0092 end
0093 
0094 %%%%%%%%%% CLUSTERING
0095 normsq = sum(waves.^2, 2);
0096 assigns = ones(M, opts.reps);
0097 mse = Inf * ones(1, opts.reps);
0098 clear fast_kmeans;
0099 
0100 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
0101 % IMPORTANT NOTE: For BLAS efficiency reasons,  %
0102 %   the waveforms & centroids matrices are      %
0103 %   transposed from their ordinary orientation  %
0104 %   in the following section of code.           %
0105 %   E.g., waveforms is (N samples x M waveforms)%
0106 waves = waves';                                 %
0107 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
0108 for rep = 1:opts.reps                                 % TOTAL # REPETITIONS
0109     
0110     centroid = mean(waves, 2);  % always start here
0111     for iter = 1:opts.divisions                       % # K-MEANS SPLITS
0112         itercounter(iter) = 0;
0113         oldassigns = zeros(M, 1);  oldmse = Inf;
0114         assign_converge = 0;       mse_converge = 0;
0115 
0116         if (iter == opts.divisions), progress = opts.progress;  reassign_criterion = opts.reassign_converge; 
0117         else,                        progress = 0;              reassign_criterion = opts.reassign_rough;
0118         end
0119         
0120         if ((iter==opts.divisions) && opts.splitmore)  % do some extra splitting on clusters that look too big
0121             toobig = find(clustersizes > 2*target_clustersize);
0122             splitbig = centroid(:,toobig) + randn(length(toobig),size(centroid,1))';
0123             centroid = [centroid splitbig];
0124         end
0125 
0126         centroid = [centroid centroid] + jitter * randn(2*size(centroid, 2),size(centroid,1))'; % split & jitter
0127         clustersizes = ones(1,size(centroid,2));  % temporary placeholder to indicate no empty clusters at first
0128         
0129         if (opts.progress)
0130             progressBar(0, 1, sprintf('Calculating %d means.', size(centroid,2))); % crude ...
0131         end
0132         
0133         while (~(assign_converge | mse_converge))     % convergence?
0134             %%%%% Clean out empty clusters
0135             centroid(:,[clustersizes == 0]) = [];
0136                         
0137             %%%%% EM STEP
0138             [assignlist, bestdists, clustersizes, centroid] = ...
0139                         fast_kmeans_step(waves, centroid, normsq, opts.memorylimit);
0140                         
0141             %%%%% Compute convergence info
0142             mse(rep) = mean(bestdists);
0143             mse_converge = ((1 - (mse(rep)/oldmse)) <= opts.mse_converge);   % fractional change
0144             oldmse = mse(rep);
0145             
0146             changed_assigns = sum(assignlist ~= oldassigns);
0147             assign_converge = (changed_assigns <= reassign_criterion);   % num waveforms reassigned
0148             if (progress)
0149                 progressBar(((M - changed_assigns)/M).^10, 5); % rough ...
0150             end
0151             oldassigns = assignlist;
0152             itercounter(iter) = itercounter(iter) + 1;
0153         end
0154     end
0155     
0156     % Finally, reassign spikes in singleton clusters to next best ...
0157     junkclust = find([clustersizes <= 1]);   junkspikes = ismember(assignlist, junkclust);
0158     centroid(:,junkclust) = Inf;  % take these out of the running & recompute distances
0159     if (any(junkspikes))
0160         dist_to_rest = pairdist(waves(:,junkspikes)', centroid', 'nosqrt');
0161         [bestdists(junkspikes), assignlist(junkspikes)] = min(dist_to_rest, [], 2);
0162     end
0163     mse(rep) = mean(bestdists);
0164     
0165     assigns(:,rep) = assignlist;
0166 end
0167 
0168 spikes.overcluster.iteration_count = itercounter;
0169 spikes.overcluster.randn_state = randnstate;
0170 spikes.overcluster.rand_state = randstate;
0171 
0172 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
0173 % Waves & centroids now return to their usual   %
0174 % orientation.  E.g., waveforms is again        %
0175 %    M waveforms x N samples                    %
0176 waves = waves';                                 %
0177 centroid = centroid';                           %
0178 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
0179 
0180 
0181 % Finish up by selecting the lowest mse over repetitions.
0182 [bestmse, choice] = min(mse);
0183 spikes.overcluster.assigns = sortAssignments(assigns(:,choice));
0184 spikes.overcluster.mse = bestmse;
0185 spikes.overcluster.sqerr = bestdists;
0186 
0187 % We also save the winning cluster centers as a convenience
0188 numclusts = max(spikes.overcluster.assigns);
0189 spikes.overcluster.centroids = zeros(numclusts, N);
0190 for clust = 1:numclusts
0191     members = find(spikes.overcluster.assigns == clust);
0192     spikes.overcluster.centroids(clust,:) = mean(waves(members,:), 1);
0193 end
0194 
0195 % And W, B, T matrices -- easy since T & B are fast to compute and T = W + B)
0196 spikes.overcluster.T = cov(waves);             % normalize everything by M-1, not M
0197 spikes.overcluster.B = cov(spikes.overcluster.centroids(spikes.overcluster.assigns, :));
0198 spikes.overcluster.W = spikes.overcluster.T - spikes.overcluster.B;
0199 
0200 % Finally, assign colors to the spike clusters here for consistency ...
0201 cmap = jetm(numclusts);
0202 spikes.overcluster.colors = cmap(randperm(numclusts),:);
0203 
0204 if (opts.progress), progressBar(1.0, 1, ''); end
0205 spikes.tictoc.kmeans = etime(clock, starttime);
