ss_outliers

PURPOSE

SS_OUTLIERS K-means based outlier detection.

SYNOPSIS

function spikes = ss_outliers(spikes, reps)

DESCRIPTION

 SS_OUTLIERS  K-means based outlier detection.
     SPIKES = SS_OUTLIERS(SPIKES) takes and returns a spike-sorting object
     SPIKES.  It identifies likely outliers in the SPIKES.WAVEFORMS and moves
     them from the (initially M x N) SPIKES.WAVEFORMS matrix to a (P x N)
     SPIKES.OUTLIERS.WAVEFORMS (usually P << M), removing them from the original
     SPIKES.WAVEFORMS matrix.  If the SPIKES.OUTLIERS.WAVEFORMS matrix already
     exists, new outliers are added to the existing matrix.

     Outlier spike timestamps are recorded in SPIKES.OUTLIERS.SPIKETIMES
     (these event times are removed from SPIKES.SPIKETIMES).  The cell matrix
     SPIKES.OUTLIERS.WHY stores a string describing the reason that the
     corresponding waveform was marked as an outlier.
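
The bookkeeping described above can be sketched on a toy structure. The waveform values, spike times, and the BAD index list are made up for illustration; the field names, including the GOODINDS and BADINDS index lists, follow the source listing of this function.

```matlab
% Toy spikes structure (hypothetical values): 5 waveforms of 3 samples each.
spikes.waveforms  = [1 1 1; 2 2 2; 9 9 9; 3 3 3; 8 8 8];
spikes.spiketimes = [0.10; 0.25; 0.40; 0.55; 0.70];
M = size(spikes.waveforms, 1);

spikes.outliers.waveforms  = [];
spikes.outliers.spiketimes = [];
spikes.outliers.why        = {};
spikes.outliers.goodinds   = (1:M)';   % original row of each remaining waveform
spikes.outliers.badinds    = [];       % original rows of removed waveforms

bad = [3 5];   % rows flagged on this pass (assumed here, not computed)

% Append the flagged rows to the outlier store ...
spikes.outliers.waveforms  = cat(1, spikes.outliers.waveforms, spikes.waveforms(bad,:));
spikes.outliers.spiketimes = cat(1, spikes.outliers.spiketimes, spikes.spiketimes(bad,:));
spikes.outliers.why        = cat(1, spikes.outliers.why, ...
                                 repmat({'K-means Outliers'}, [length(bad), 1]));
spikes.outliers.badinds    = cat(1, spikes.outliers.badinds, ...
                                 spikes.outliers.goodinds(bad(:)));

% ... then delete them from the main lists.
spikes.outliers.goodinds(bad) = [];
spikes.spiketimes(bad,:)      = [];
spikes.waveforms(bad,:)       = [];
```

Because GOODINDS shrinks in step with the waveform matrix, indices found on a later pass still map back to rows of the original M x N matrix.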
 
     The SS_OUTLIERS function identifies outliers using an ad hoc heuristic
     based on a k-means clustering; waveforms that end up very far from their 
     assigned centroid are likely outliers.   The scale for this determination
     comes from the average of the covariance matrices of each cluster (i.e.,
     1/(N-1) times the (N x N) within-group sum of squares matrix).  We take
     this as an approximation to the noise covariance in outlier-free clusters 
     and note that if the noise were locally Gaussian, then the squared Mahalanobis
     distances of waveforms to their assigned cluster means should be roughly Chi^2
     distributed with N degrees of freedom (actually, F-distributed, but we ignore
     that refinement for now).
     NOTE: Outliers damage the k-means solution; the clustering is not robust to
            gross violations of its (local) Gaussian assumption.  Repeated clustering
            on the cleaned data will thus yield a new solution, which might
            uncover further outliers.  By default this function attempts 3
            cluster/clean iterations (a rule of thumb that trades off cleaning
            against run-time), although it stops after any iteration that does
            not find an outlier.
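
A minimal sketch of the distance test on toy 2-D data follows. The waveforms, assignments, centroids, and covariance W are hypothetical stand-ins for what the k-means step would supply; the distance expression mirrors the one in the source listing. Since N = 2 here, the Chi^2 inverse CDF is written in closed form so the sketch does not need CHI2INV.

```matlab
% Toy stand-ins for the k-means output (all values hypothetical).
waveforms = [0.5 0.0; -0.3 0.4; 10.2 9.9; 9.8 10.5; 4.0 -3.0];   % M x N
assigns   = [1; 1; 2; 2; 1];       % cluster index of each waveform
centroids = [0 0; 10 10];          % one row per cluster
W         = 0.25 * eye(2);         % assumed average within-cluster covariance

[M, N] = size(waveforms);

% Squared Mahalanobis distance of each waveform to its assigned centroid,
% (x - c)' * inv(W) * (x - c), computed for all rows at once.
vectors_to_centers = waveforms - centroids(assigns, :);
mahaldists = sum(vectors_to_centers .* (W \ vectors_to_centers')', 2)';

% Default cutoff, converted to a threshold on the squared distance.
cutoff = 1 - 1/(100*M);
% For N = 2 the Chi^2 inverse CDF is -2*log(1 - p); for general N the
% source calls chi2inv(cutoff, N) instead.
threshold = -2 * log(1 - cutoff);

bad = find(mahaldists > threshold);   % indices of putative outliers
```

In this toy case only the last waveform, which sits far from the centroid it was assigned to, exceeds the threshold.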

     The default Chi^2 CDF cutoff is (1 - 1/(100*M)).  To use a different cutoff,
     specify the value in 'SPIKES.OPTIONS.OUTLIER_CUTOFF'; the closer the value
     (e.g., 1-1e-8) is to 1, the fewer spikes will be considered outliers.
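
One way to read the default cutoff: if the Chi^2 model held exactly, each clean spike would exceed the threshold with probability 1/(100*M), so about 0.01 clean spikes would be flagged per pass regardless of M. The sketch below checks that arithmetic for a hypothetical M and shows how an override would be stored; the numbers are illustrative only.

```matlab
M = 20000;                             % hypothetical number of waveforms
default_cutoff = 1 - 1/(100*M);        % default from the description

% Expected number of clean spikes flagged per pass under the Chi^2 model.
expected_false_hits = M * (1 - default_cutoff);

% Overriding the default: a value closer to 1 flags fewer spikes.
spikes.options.outlier_cutoff = 1 - 1e-8;
```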

     SPIKES = SS_OUTLIERS(SPIKES, REPS) performs REPS cluster/clean iterations
     (default: 3).  

     The reason given for these outliers in SPIKES.OUTLIERS.WHY is "K-means Outliers".

     NOTE: This function performs a k-means clustering and thus overwrites
           existing k-means assignments in the SPIKES object.

SOURCE CODE

function spikes = ss_outliers(spikes, reps)
% SS_OUTLIERS  K-means based outlier detection.
%     SPIKES = SS_OUTLIERS(SPIKES) takes and returns a spike-sorting object
%     SPIKES.  It identifies likely outliers in the SPIKES.WAVEFORMS and moves
%     them from the (initially M x N) SPIKES.WAVEFORMS matrix to a (P x N)
%     SPIKES.OUTLIERS.WAVEFORMS (usually P << M), removing them from the original
%     SPIKES.WAVEFORMS matrix.  If the SPIKES.OUTLIERS.WAVEFORMS matrix already
%     exists, new outliers are added to the existing matrix.
%
%     Outlier spike timestamps are recorded in SPIKES.OUTLIERS.SPIKETIMES
%     (these event times are removed from SPIKES.SPIKETIMES).  The cell matrix
%     SPIKES.OUTLIERS.WHY stores a string describing the reason that the
%     corresponding waveform was marked as an outlier.
%
%     The SS_OUTLIERS function identifies outliers using an ad hoc heuristic
%     based on a k-means clustering; waveforms that end up very far from their
%     assigned centroid are likely outliers.   The scale for this determination
%     comes from the average of the covariance matrices of each cluster (i.e.,
%     1/(N-1) times the (N x N) within-group sum of squares matrix).  We take
%     this as an approximation to the noise covariance in outlier-free clusters
%     and note that if the noise were locally Gaussian, then the squared Mahalanobis
%     distances of waveforms to their assigned cluster means should be roughly Chi^2
%     distributed with N degrees of freedom (actually, F-distributed, but we ignore
%     that refinement for now).
%     NOTE: Outliers damage the k-means solution; the clustering is not robust to
%            gross violations of its (local) Gaussian assumption.  Repeated clustering
%            on the cleaned data will thus yield a new solution, which might
%            uncover further outliers.  By default this function attempts 3
%            cluster/clean iterations (a rule of thumb that trades off cleaning
%            against run-time), although it stops after any iteration that does
%            not find an outlier.
%
%     The default Chi^2 CDF cutoff is (1 - 1/(100*M)).  To use a different cutoff,
%     specify the value in 'SPIKES.OPTIONS.OUTLIER_CUTOFF'; the closer the value
%     (e.g., 1-1e-8) is to 1, the fewer spikes will be considered outliers.
%
%     SPIKES = SS_OUTLIERS(SPIKES, REPS) performs REPS cluster/clean iterations
%     (default: 3).
%
%     The reason given for these outliers in SPIKES.OUTLIERS.WHY is "K-means Outliers".
%
%     NOTE: This function performs a k-means clustering and thus overwrites
%           existing k-means assignments in the SPIKES object.

%   Last Modified By: sbm on Thu Oct  6 20:15:56 2005

starttime = clock;

%%%%% ARGUMENT CHECKING
% Short-circuit OR so that size() is never evaluated on a missing field.
if (~isfield(spikes, 'waveforms') || (size(spikes.waveforms, 1) < 1))
    error('SS:waveforms_undefined', 'The SS object does not contain any waveforms!');
end

[M,N] = size(spikes.waveforms);
if (isfield(spikes, 'options') && isfield(spikes.options, 'outlier_cutoff'))
    cutoff = spikes.options.outlier_cutoff;
else
    cutoff = (1 - (1/(100*M)));
end

% REPS is the second argument, so the default applies when fewer than 2 are given.
% (The original test, nargin < 3, was always true and silently overrode REPS.)
if (nargin < 2),  reps = 3;  end
times_defined = isfield(spikes, 'spiketimes');

% Initialize the outliers sub-structure
if (~isfield(spikes, 'outliers'))
    spikes.outliers.waveforms = [];
    spikes.outliers.why = {};
    if (times_defined)
        spikes.outliers.spiketimes = [];
    end

    spikes.outliers.goodinds = [1:M]';  % We need these to re-insert the outlier 'cluster'
    spikes.outliers.badinds = [];      % into the waveforms matrix after sorting.
end

for cleaning = 1:reps   % Cluster/clean then rinse/repeat.  3 reps does a good job w/o taking too much time

    progressBar((cleaning-1)./reps, 1, ['Removing Outliers: Pass ' num2str(cleaning)]);

    %%%%% PERFORM A K-MEANS CLUSTERING OF THE DATA
    opts.mse_converge = 0.001;        % rough clustering is fine
    opts.progress = 0;
    spikes = ss_kmeans(spikes, opts);

    %%%%% MAHALANOBIS DISTANCES:   (x - x_mean)' * S^-1 * (x - x_mean)
    vectors_to_centers = spikes.waveforms - spikes.overcluster.centroids(spikes.overcluster.assigns,:);
    mahaldists = sum(vectors_to_centers .* (spikes.overcluster.W \ vectors_to_centers')', 2)';

    %%%%% SPLIT OFF OUTLIERS
    bad = find(mahaldists > chi2inv(cutoff, N));   % find putative outliers

    if (isempty(bad))
        break;              % didn't find anything ... no sense continuing
    else
        %%%%% ADD OUTLIERS TO SS OBJECT AND REMOVE FROM MAIN LIST
        spikes.outliers.waveforms = cat(1, spikes.outliers.waveforms, spikes.waveforms(bad,:));
        spikes.outliers.why = cat(1, spikes.outliers.why, repmat({'K-means Outliers'}, [length(bad), 1]));
        spikes.outliers.badinds = cat(1, spikes.outliers.badinds, spikes.outliers.goodinds(bad(:)));
        if (times_defined)
            spikes.outliers.spiketimes = cat(1, spikes.outliers.spiketimes, spikes.spiketimes(bad,:));
            spikes.spiketimes(bad,:) = [];
        end
        spikes.outliers.goodinds(bad) = [];
        spikes.waveforms(bad,:) = [];
        spikes = rmfield(spikes, 'overcluster');  % this isn't supposed to be public-visible
    end
end

progressBar(1.0, 1, '');

%%%%% TIMING INFORMATION
spikes.tictoc = rmfield(spikes.tictoc, 'kmeans');
spikes.tictoc.outliers = etime(clock, starttime);

Generated on Sun 13-Aug-2006 11:49:44 by m2html © 2003