
ss_kmeans

PURPOSE

SS_KMEANS K-means clustering.

SYNOPSIS

function spikes = ss_kmeans(spikes, options)

DESCRIPTION

 SS_KMEANS  K-means clustering.
     SPIKES = SS_KMEANS(SPIKES) takes and returns a spike-sorting object SPIKES.
     It k-means clusters the (M x N) matrix SPIKES.WAVEFORMS and stores the
     resulting group assignments in SPIKES.OVERCLUSTER.ASSIGNS, the cluster
     centroids in SPIKES.OVERCLUSTER.CENTROIDS, and the mean squared error in
     SPIKES.OVERCLUSTER.MSE.  The within-cluster, between-cluster, and total
     covariance matrices (W, B, and T) are stored in SPIKES.OVERCLUSTER.W,
     SPIKES.OVERCLUSTER.B, and SPIKES.OVERCLUSTER.T.
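A minimal MATLAB sketch, on made-up data, of how the three matrices relate. T and B are computed the way the source listing computes them (COV of the waveforms, and COV of each waveform's own centroid), W is taken as T - B, and the result is checked against a direct within-cluster sum. The variables X, assigns, C, and Wdirect are illustrative only.

```matlab
% Toy data (hypothetical): M = 6 waveforms, N = 2 samples, two clusters
X = [0 0; 1 0; 0 1; 5 5; 6 5; 5 6];
assigns = [1; 1; 1; 2; 2; 2];
K = max(assigns);
[M, N] = size(X);

% Cluster centroids: the mean of each cluster's members
C = zeros(K, N);
for k = 1:K
    C(k, :) = mean(X(assigns == k, :), 1);
end

% Total and between-cluster covariances (COV normalizes both by M - 1)
T = cov(X);
B = cov(C(assigns, :));
W = T - B;                  % within-cluster covariance, since T = W + B

% Direct within-cluster scatter, for comparison
Wdirect = zeros(N);
for k = 1:K
    members = (assigns == k);
    D = X(members, :) - repmat(C(k, :), sum(members), 1);
    Wdirect = Wdirect + D' * D;
end
Wdirect = Wdirect / (M - 1);
```

The identity holds because each centroid is the mean of its own members, so the mean of C(assigns,:) equals the overall mean of X.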

     The k-means algorithm is an EM-like algorithm that finds K cluster centers
     in N-dimensional space and assigns each of the M data vectors to one of
     these K points.  The process is iterative: new cluster centers are
     calculated as the mean of their previously assigned vectors, and each
     vector is then reassigned to its closest cluster center.  This finds a
     local minimum of the mean squared distance from each point to its cluster
     center; this is also a (local) MLE for a model of K isotropic Gaussian
     clusters.
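The two alternating steps can be sketched in plain MATLAB (no toolboxes) as ordinary Lloyd iterations with a fixed K. This is an illustration under stated assumptions, not the FAST_KMEANS_STEP routine that SS_KMEANS actually calls: the data, the deliberately poor starting centers, and the 100-iteration cap are all made up.

```matlab
% Hypothetical toy data: M = 6 vectors in N = 2 dimensions
X = [0 0; 0 1; 1 0; 10 10; 10 11; 11 10];
centroids = [0 0; 1 0];          % K = 2 deliberately poor starting centers
[M, N] = size(X);
K = size(centroids, 1);
assigns = zeros(M, 1);

for iter = 1:100                 % iteration cap (assumed)
    % Squared Euclidean distance from every vector to every center (M x K)
    d2 = zeros(M, K);
    for k = 1:K
        diffs = X - repmat(centroids(k, :), M, 1);
        d2(:, k) = sum(diffs.^2, 2);
    end
    % Reassignment step: each vector goes to its closest center
    [bestdists, newassigns] = min(d2, [], 2);
    if all(newassigns == assigns)
        break;                   % no vector changed cluster: converged
    end
    assigns = newassigns;
    % Update step: each center moves to the mean of its assigned vectors
    for k = 1:K
        if any(assigns == k)
            centroids(k, :) = mean(X(assigns == k, :), 1);
        end
    end
end
mse = mean(bestdists);           % mean squared distance, as in the listing
```

Starting from these centers, the first pass assigns the point (1, 0) to the wrong group; the second pass corrects it, and the third pass changes nothing, giving assignments [1;1;1;2;2;2] and an MSE of 4/9.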
 
     This method is highly sensitive to outliers: the MLE is not robust, so a
     few waveforms that are unlike the others will shift the 'true' cluster
     centers (or create new ones) to account for them.  See SS_OUTLIERS for
     one solution.
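A small worked example of this sensitivity, in MATLAB, with made-up one-dimensional values: because a k-means center is a plain mean, a single artifact among six points drags it far from the typical values.

```matlab
% Hypothetical 1-D peak amplitudes: five typical spikes and one artifact
typical = [1.0 1.1 0.9 1.0 1.0];
withOutlier = [typical 20];

c1 = mean(typical);       % center without the artifact
c2 = mean(withOutlier);   % center once the artifact is included
shift = c2 - c1;          % the single artifact moves the center by about 3.17
```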
 
     The algorithm used here speeds convergence by first solving for 2 means,
     then using these means (slightly jittered) as starting points for a
     4-means solution.  This continues for log2(K) steps until K means have
     been found.
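The doubling schedule can be sketched in MATLAB as below. This is a simplified illustration, not the SS_KMEANS code: it omits the k-means refinement between doublings, the removal of empty clusters, and the SPLITMORE step, and the values of N, jitter, and divisions are made up (SS_KMEANS itself derives jitter from the waveform data).

```matlab
N = 3;                    % samples per waveform (hypothetical)
jitter = 0.01;            % size of the random nudge (assumed value)
divisions = 3;            % number of doublings, i.e. 2^3 = 8 final centers
centroid = zeros(N, 1);   % one starting center; columns are centers, as in the source

for iter = 1:divisions
    K = size(centroid, 2);
    % Duplicate every center and nudge each copy by a small random offset
    centroid = [centroid centroid] + jitter * randn(N, 2 * K);
    % (a full k-means refinement of all 2*K centers would run here)
end
```

Seeding each stage from the previous stage's solution means every refinement starts close to a good answer, which is why fewer EM steps are needed than with K random starting points.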

     SPIKES = SS_KMEANS(SPIKES, OPTIONS) allows specification of clustering
     parameters.  OPTIONS is a structure with some or all of the following
     fields defined.  Any field left undefined (or every field, if no OPTIONS
     structure is passed in) takes its default value.

         OPTIONS.DIVISIONS (default: round(log2(M/500)), clipped to the range
               [4, 8]) sets the desired number of clusters to 2^DIVISIONS.  The
               actual number of clusters may be slightly more or fewer.
         OPTIONS.REPS (default: 1) specifies the number of runs of the full
               k-means solution.  The function returns the assignments from
               the run with the minimum MSE.
         OPTIONS.REASSIGN_CONVERGE (default: 0) defines a convergence condition
               by specifying the maximum number of vectors allowed to be
               reassigned in an EM step.  If no more than this number of
               vectors is reassigned, this condition is met.  This criterion
               applies to the final split; intermediate splits use a looser
               one of round(0.005*M).
         OPTIONS.MSE_CONVERGE (default: 0) defines a second convergence condition.
               If the fractional decrease in mean squared error from one
               iteration to the next is at most this value, this condition is met.
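The default for OPTIONS.DIVISIONS can be worked through in MATLAB for a few hypothetical recording sizes; the constant 500 is the target cluster size hard-coded in the source, and the values of Ms are made up.

```matlab
target_clustersize = 500;      % constant used in SS_KMEANS
Ms = [1000 20000 300000];      % hypothetical numbers of waveforms
divs = zeros(size(Ms));
for i = 1:numel(Ms)
    d = round(log2(Ms(i) / target_clustersize));
    divs(i) = max(min(d, 8), 4);   % clip to [4, 8]
end
nclusters = 2 .^ divs;         % desired cluster counts: 16, 32, 256
```

With these defaults a recording of 20000 waveforms is split into about 32 clusters; setting OPTIONS.DIVISIONS explicitly overrides the heuristic.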

     NOTE: Iteration stops when either of the convergence conditions is met.
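The two stopping tests reduce to the comparisons below, shown in MATLAB with made-up MSE values and assignments from two consecutive EM steps. MSE_CONVERGE is set to a non-default 0.01 purely for illustration.

```matlab
% Hypothetical values from two consecutive EM steps
oldmse = 10.0;   newmse = 9.95;
oldassigns = [1; 2; 2; 3; 1];
newassigns = [1; 2; 3; 3; 1];
reassign_converge = 0;     % OPTIONS.REASSIGN_CONVERGE (default)
mse_converge = 0.01;       % OPTIONS.MSE_CONVERGE (non-default, for illustration)

changed = sum(newassigns ~= oldassigns);           % vectors that switched cluster
assign_ok = (changed <= reassign_converge);        % first condition
mse_ok = ((1 - newmse / oldmse) <= mse_converge);  % fractional MSE decrease: second condition
stop = assign_ok || mse_ok;                        % either condition ends the iteration
```

Here one vector still moved, so the reassignment test fails, but the MSE fell by only 0.5%, so iteration stops anyway.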
  
 References:
     Duda RO et al (2001).  _Pattern Classification_, Wiley-Interscience

SOURCE CODE

function spikes = ss_kmeans(spikes, options)
% SS_KMEANS  K-means clustering.
%     SPIKES = SS_KMEANS(SPIKES) takes and returns a spike-sorting object SPIKES.
%     It k-means clusters the (M x N) matrix SPIKES.WAVEFORMS and stores the
%     resulting group assignments in SPIKES.OVERCLUSTER.ASSIGNS, the cluster
%     centroids in SPIKES.OVERCLUSTER.CENTROIDS, and the mean squared error in
%     SPIKES.OVERCLUSTER.MSE.  The within-cluster, between-cluster, and total
%     covariance matrices (W, B, and T) are stored in SPIKES.OVERCLUSTER.W,
%     SPIKES.OVERCLUSTER.B, and SPIKES.OVERCLUSTER.T.
%
%     The k-means algorithm is an EM-like algorithm that finds K cluster centers
%     in N-dimensional space and assigns each of the M data vectors to one of
%     these K points.  The process is iterative: new cluster centers are
%     calculated as the mean of their previously assigned vectors, and each
%     vector is then reassigned to its closest cluster center.  This finds a
%     local minimum of the mean squared distance from each point to its cluster
%     center; this is also a (local) MLE for a model of K isotropic Gaussian
%     clusters.
%
%     This method is highly sensitive to outliers: the MLE is not robust, so a
%     few waveforms that are unlike the others will shift the 'true' cluster
%     centers (or create new ones) to account for them.  See SS_OUTLIERS for
%     one solution.
%
%     The algorithm used here speeds convergence by first solving for 2 means,
%     then using these means (slightly jittered) as starting points for a
%     4-means solution.  This continues for log2(K) steps until K means have
%     been found.
%
%     SPIKES = SS_KMEANS(SPIKES, OPTIONS) allows specification of clustering
%     parameters.  OPTIONS is a structure with some or all of the following
%     fields defined.  Any field left undefined (or every field, if no OPTIONS
%     structure is passed in) takes its default value.
%
%         OPTIONS.DIVISIONS (default: round(log2(M/500)), clipped to the range
%               [4, 8]) sets the desired number of clusters to 2^DIVISIONS.  The
%               actual number of clusters may be slightly more or fewer.
%         OPTIONS.REPS (default: 1) specifies the number of runs of the full
%               k-means solution.  The function returns the assignments from
%               the run with the minimum MSE.
%         OPTIONS.REASSIGN_CONVERGE (default: 0) defines a convergence condition
%               by specifying the maximum number of vectors allowed to be
%               reassigned in an EM step.  If no more than this number of
%               vectors is reassigned, this condition is met.  This criterion
%               applies to the final split; intermediate splits use a looser
%               one of round(0.005*M).
%         OPTIONS.MSE_CONVERGE (default: 0) defines a second convergence condition.
%               If the fractional decrease in mean squared error from one
%               iteration to the next is at most this value, this condition is met.
%
%     NOTE: Iteration stops when either of the convergence conditions is met.
%
% References:
%     Duda RO et al (2001).  _Pattern Classification_, Wiley-Interscience

%   Last Modified By: sbm on Sun Aug 13 02:28:36 2006

% Undocumented options:
%       OPTIONS.PROGRESS (default: 1) determines whether the progress
%                                     bar is displayed during the clustering.
%       OPTIONS.SPLITMORE (default: 1) determines whether to do extra cluster
%                                       splitting to try for smaller clusters.

starttime = clock;

% save the random number seeds before we do anything, in case we want to
% replicate this run exactly ...
randnstate = randn('state');  randstate = rand('state');

%%%%%%%%%% ARGUMENT CHECKING
if (~isfield(spikes, 'waveforms') || (size(spikes.waveforms, 1) < 1))
    error('SS:waveforms_undefined', 'The SS object does not contain any waveforms!');
end

%%%%%%%%%% CONSTANTS
target_clustersize = 500;
waves = spikes.waveforms;         % ref w/o struct is better for R13 acceleration
[M,N] = size(waves);
jitter = meandist_estim(waves) / 100 / N;        % heuristic

%%%%%%%%%% DEFAULTS
opts.divisions = round(log2(M / target_clustersize));  % power of 2 that gives closest to target_clustersize, but in [4..8]
opts.divisions = max(min(opts.divisions, 8), 4);       %     restrict to 16-256 clusters; heuristic.
% opts.memorylimit = M;                 % (alternative: all spikes at once)
opts.memorylimit = 300;                 % passed through to fast_kmeans_step
opts.reps = 1;                          % just one repetition
opts.reassign_converge = 0;             % the last few stubborn ones are probably borderline cases anyway ...
opts.reassign_rough = round(0.005*M);   % (no need at all to get full convergence for intermediate runs)
opts.mse_converge = 0;                  % ... and by default, we don't use the mse convergence criterion
opts.progress = 1;
opts.splitmore = 1;
if (nargin > 1)
    supplied = fieldnames(options);      % which options did the user specify?
    for op = 1:length(supplied)          % copy those over the defaults (field names are case-insensitive)
        opts.(lower(supplied{op})) = options.(supplied{op});  % dynamic field names: preferred syntax as of R13
    end
end

%%%%%%%%%% CLUSTERING
normsq = sum(waves.^2, 2);
assigns = ones(M, opts.reps);
sqerrs = zeros(M, opts.reps);     % squared error of each waveform, per repetition
mse = Inf * ones(1, opts.reps);
clear fast_kmeans;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% IMPORTANT NOTE: For BLAS efficiency reasons,  %
%   the waveforms & centroids matrices are      %
%   transposed from their ordinary orientation  %
%   in the following section of code.           %
%   E.g., waveforms is (N samples x M waveforms)%
waves = waves';                                 %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
for rep = 1:opts.reps                                 % TOTAL # REPETITIONS

    centroid = mean(waves, 2);  % always start here
    clustersizes = M;           % one cluster holding every waveform (needed if DIVISIONS == 1)
    itercounter = zeros(1,opts.divisions);
    for iter = 1:opts.divisions                       % # K-MEANS SPLITS
        itercounter(iter) = 0;
        oldassigns = zeros(M, 1);  oldmse = Inf;
        assign_converge = 0;       mse_converge = 0;

        if (iter == opts.divisions)
            progress = opts.progress;  reassign_criterion = opts.reassign_converge;
        else
            progress = 0;              reassign_criterion = opts.reassign_rough;
        end

        if ((iter==opts.divisions) && opts.splitmore)  % do some extra splitting on clusters that look too big
            toobig = find(clustersizes > 2*target_clustersize);
            splitbig = centroid(:,toobig) + randn(length(toobig),size(centroid,1))';
            centroid = [centroid splitbig];
        end

        centroid = [centroid centroid] + jitter * randn(2*size(centroid, 2),size(centroid,1))'; % split & jitter
        clustersizes = ones(1,size(centroid,2));  % temporary placeholder to indicate no empty clusters at first

        if (opts.progress)
            progressBar(0, 1, sprintf('Calculating %d means.', size(centroid,2))); % crude ...
        end

        while (~(assign_converge || mse_converge))     % convergence?
            %%%%% Clean out empty clusters
            centroid(:,(clustersizes == 0)) = [];

            %%%%% EM STEP
            [assignlist, bestdists, clustersizes, centroid] = ...
                        fast_kmeans_step(waves, centroid, normsq, opts.memorylimit);

            %%%%% Compute convergence info
            mse(rep) = mean(bestdists);
            mse_converge = ((1 - (mse(rep)/oldmse)) <= opts.mse_converge);   % fractional change
            oldmse = mse(rep);

            changed_assigns = sum(assignlist ~= oldassigns);
            assign_converge = (changed_assigns <= reassign_criterion);   % num waveforms reassigned
            if (progress)
                progressBar(((M - changed_assigns)/M).^10, 5); % rough ...
            end
            oldassigns = assignlist;
            itercounter(iter) = itercounter(iter) + 1;
        end
    end

    % Finally, reassign spikes in singleton clusters to next best ...
    junkclust = find((clustersizes <= 1));   junkspikes = ismember(assignlist, junkclust);
    centroid(:,junkclust) = Inf;  % take these out of the running & recompute distances
    if (any(junkspikes))
        dist_to_rest = pairdist(waves(:,junkspikes)', centroid', 'nosqrt');
        [bestdists(junkspikes), assignlist(junkspikes)] = min(dist_to_rest, [], 2);
    end
    mse(rep) = mean(bestdists);

    assigns(:,rep) = assignlist;
    sqerrs(:,rep) = bestdists(:);
end

spikes.overcluster.iteration_count = itercounter;
spikes.overcluster.randn_state = randnstate;
spikes.overcluster.rand_state = randstate;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Waves & centroids now return to their usual   %
% orientation.  E.g., waveforms is again        %
%    M waveforms x N samples                    %
waves = waves';                                 %
centroid = centroid';                           %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


% Finish up by selecting the lowest mse over repetitions.
[bestmse, choice] = min(mse);
spikes.overcluster.assigns = sortAssignments(assigns(:,choice));
spikes.overcluster.mse = bestmse;
spikes.overcluster.sqerr = sqerrs(:,choice);   % squared errors from the winning run, not just the last one

% We also save the winning cluster centers as a convenience
numclusts = max(spikes.overcluster.assigns);
spikes.overcluster.centroids = zeros(numclusts, N);
for clust = 1:numclusts
    members = find(spikes.overcluster.assigns == clust);
    spikes.overcluster.centroids(clust,:) = mean(waves(members,:), 1);
end

% And the W, B, T matrices: T and B are fast to compute, and T = W + B, so W = T - B
spikes.overcluster.T = cov(waves);             % normalize everything by M-1, not M
spikes.overcluster.B = cov(spikes.overcluster.centroids(spikes.overcluster.assigns, :));
spikes.overcluster.W = spikes.overcluster.T - spikes.overcluster.B;

% Finally, assign colors to the spike clusters here for consistency ...
cmap = jetm(numclusts);
spikes.overcluster.colors = cmap(randperm(numclusts),:);

if (opts.progress), progressBar(1.0, 1, ''); end
spikes.tictoc.kmeans = etime(clock, starttime);
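The final clean-up step in the listing (reassigning waveforms from singleton clusters to their next-best center) can be illustrated with the toy MATLAB sketch below. Unlike the listing, which sets the discarded centroids to Inf and calls PAIRDIST, this version simply drops those centers with SETDIFF; the data and variable names are made up.

```matlab
X = [0; 0.2; 5; 5.1; 9];          % five 1-D waveforms (hypothetical)
centroid = [0.1; 5.05; 9];        % three centers; center 3 owns a single point
assigns = [1; 1; 2; 2; 3];

counts = accumarray(assigns, 1, [numel(centroid) 1]);  % members per cluster
junkclust = find(counts <= 1);                 % singleton (or empty) clusters
junk = ismember(assigns, junkclust);           % waveforms in those clusters
keep = setdiff(1:numel(centroid), junkclust);  % surviving centers

% Squared distance from each junk waveform to every surviving center
d2 = (repmat(X(junk), 1, numel(keep)) - repmat(centroid(keep)', sum(junk), 1)).^2;
[~, idx] = min(d2, [], 2);
assigns(junk) = keep(idx);                     % move each one to its next-best center
```

The lone point at 9 is closer to the center at 5.05 than to the one at 0.1, so it joins cluster 2.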

Generated on Mon 09-Oct-2006 00:54:52 by m2html © 2003