


SS_AGGREGATE ISI-restricted hierarchical cluster aggregation.
SPIKES = SS_AGGREGATE(SPIKES) takes and returns a spike-sorting object
SPIKES after aggregating clusters (requires a previous overclustering
and an interface energy matrix calculation). The aggregation tree is
stored in SPIKES.HIERARCHY.TREE and the new assignments are stored in
SPIKES.HIERARCHY.ASSIGNS.
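A typical call sequence, assuming the usual companion routines for
overclustering and interface energy computation (named SS_KMEANS and
SS_ENERGY here for illustration; substitute whatever functions fill
SPIKES.OVERCLUSTER and SPIKES.HIERARCHY.INTERFACE_ENERGY in your pipeline),
is sketched below:
    spikes = ss_kmeans(spikes);      % overclustering -> spikes.overcluster.assigns
    spikes = ss_energy(spikes);      % interface energies -> spikes.hierarchy.interface_energy
    spikes = ss_aggregate(spikes);   % hierarchical aggregation (this function)
    assigns = spikes.hierarchy.assigns;   % final cluster labels (0 = outlier)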
The algorithm computes a similarity/connection matrix using the interface
energy. It then chooses the cluster pair with the highest connection
strength, aggregates them, recalculates connection strengths, and then
repeats the process. Cluster aggregation is contingent on passing an
interspike interval (ISI) test if SPIKES.SPIKETIMES is defined; if this
test is not passed, the pair is not aggregated and aggregation continues.
Aggregation stops when either (1) all remaining pairs fail the ISI test
or (2) the connection strength drops below a (heuristic) cutoff of 0.01.
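As an illustration of the connection-strength computation (a sketch that
mirrors the normalization in the source listing below, using hypothetical
local variable names), the strength for every cluster pair can be derived
from the interface energy matrix and the per-cluster spike counts as follows:
    E = spikes.hierarchy.interface_energy;             % upper-triangular energy matrix
    K = max(spikes.overcluster.assigns);                % number of clusters
    n = full(sparse(spikes.overcluster.assigns, 1, 1, K, 1));   % spikes per cluster
    nrm = (n * n') - diag(n);                           % off-diag: Na*Nb; diag: Na^2-Na ...
    nrm = nrm - diag(0.5 * diag(nrm));                  % ... with the diagonal halved
    En = E ./ nrm;                                      % normalized energy
    self = repmat(diag(En), 1, K);
    C = 2 .* En ./ (self + self');                      % connection strength
    C = C .* (1 - eye(K));                              % ignore self-connections
    [strength, ind] = max(C(:));                        % strongest candidate pair ...
    [c1, c2] = ind2sub([K K], ind);                     % ... and its cluster indices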
The ISI test needs an estimate of the expected refractory period to test
for violations; by default, it uses 1.5 times the duration of the
waveforms in the SPIKES structure (but never less than 2 msec). To
override this, define SPIKES.OPTIONS.REFRACTORY_PERIOD in units of
seconds (e.g., set it to 0.0017 to indicate an expected refractory
period of 1.7 msec).
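For example, to declare a 1.7 msec refractory period before calling this
function:
    spikes.options.refractory_period = 0.0017;   % seconds
    spikes = ss_aggregate(spikes);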
The SPIKES.HIERARCHY.TREE output is a matrix describing the aggregation.
Each aggregation step produces one row, and rows appear in the order the
aggregations were performed. The first two columns are the indices of the
clusters that were aggregated; the index assigned to the aggregate in
later rows is the lower of the two original indices. The third column is
the connection strength between the clusters before aggregation and the
fourth column is the ISI statistic for the aggregate (0 if ISI statistics
are not being used).
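For example, the merge history can be printed after aggregation using only
the fields described above:
    tree = spikes.hierarchy.tree;
    for k = 1:size(tree, 1)
        fprintf('Step %d: cluster %d absorbed cluster %d (strength %.3f, ISI stat %.3f)\n', ...
                k, tree(k,1), tree(k,2), tree(k,3), tree(k,4));
    end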
After aggregation, outliers that were previously removed are by default
reinserted into the spike list so that it aligns with the original
(pre-sorting) list. The outlier waveforms and spike times are therefore
added back to the SPIKES.WAVEFORMS and SPIKES.SPIKETIMES fields,
respectively, once all other aggregation is complete. These waveforms
are assigned the label 0 in SPIKES.HIERARCHY.ASSIGNS and the other
(non-outlier) spikes are reordered accordingly. To prevent this from
occurring, pass 0 as the second argument to this function, i.e.,
    SPIKES = SS_AGGREGATE(SPIKES, 0);
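For example, to work only with the non-outlier spikes after a default
(outlier-reintegrating) call, select the entries whose label is not 0:
    spikes = ss_aggregate(spikes);              % outliers come back with label 0
    keep   = (spikes.hierarchy.assigns ~= 0);   % logical index of non-outlier spikes
    waves  = spikes.waveforms(keep, :);
    times  = spikes.spiketimes(keep);
    labels = spikes.hierarchy.assigns(keep);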
References:
    Fee MS et al. (1996). J Neurosci Methods 69: 175-188.

function spikes = ss_aggregate(spikes, reintegrate_outliers)
% SS_AGGREGATE  ISI-restricted hierarchical cluster aggregation.
% (See the help text above.)

debug = 1;

starttime = clock;

cutoff = 0.01;   % arbitrarily stop aggregation when overlap density is < 1% of main cluster density

s = warning('off', 'MATLAB:divideByZero');

%%%%% ARGUMENT CHECKING
if (~isfield(spikes, 'hierarchy') || ~isfield(spikes.hierarchy, 'interface_energy'))
    error('SS:energy_not_computed', 'An energy matrix must be computed before aggregation.');
elseif (~isfield(spikes, 'overcluster') || ~isfield(spikes.overcluster, 'assigns'))
    error('SS:overcluster_not_computed', 'The data must be overclustered before aggregation.');
elseif (isfield(spikes, 'spiketimes') && ~isfield(spikes, 'Fs'))
    error('SS:bad_stop_condition', 'A sampling frequency Fs must be supplied for an ISI stop condition.');
end
if (isfield(spikes.hierarchy, 'assigns') && any(spikes.hierarchy.assigns == 0))
    error('SS:aggregate_after_outliers', 'Aggregation can not be performed after outliers are reintegrated into the data.');
end
if (nargin < 2), reintegrate_outliers = 1; end

tmin = size(spikes.waveforms,2) ./ spikes.Fs;   % minimum possible time btw spikes
if (isfield(spikes, 'options') && isfield(spikes.options, 'refractory_period'))
    tref = spikes.options.refractory_period;
else
    tref = max(0.002, tmin*1.5);   % crude guess as to refractory period
end

%%%%% INITIALIZE A FEW THINGS
assignments = spikes.overcluster.assigns;
interface_energy = spikes.hierarchy.interface_energy;
numclusts = max(assignments);
numpts = full(sparse(assignments, 1, 1, numclusts, 1));
tree = [];
untested = 3*ones(numclusts);   % they're all initially untested

% Energy merging progress grid ...
handle_fig = figure;  handle_img = imagesc(untested);  axis square;
colormap([0 0 0; 0.9 0.4 0.4; 1 1 0; 0.4 0.8 0.2]);   % [bk, rd, yl, gn] => (0 combined, 1 not allowed, 2 testing, 3 untested)

%%%%% AGGREGATE HIGHEST CONNECTION STRENGTHS UNTIL ALL TRIED
while (any(any(triu(untested,1) == 3)))   % only the upper triangular half is meaningful
    % compute connection strengths from interface energies
    % first, normalize energy:
    normalize = ((numpts * numpts') - diag(numpts));     % Off diag: Na*Nb, On diag: Na^2-Na ...
    normalize = normalize - diag(0.5*diag(normalize));   % ... and divide diagonal by 2
    norm_energy = interface_energy ./ normalize;
    % then, compute connection strength
    self = repmat(diag(norm_energy), [1,numclusts]);
    connect_strength = 2 .* norm_energy ./ (self + self');
    connect_strength = connect_strength .* (1-eye(numclusts));   % diag entries <- 0, so we won't agg clusters with themselves

    % Find best remaining pair
    remaining = ((untested == 3) .* connect_strength);
    best = max(remaining(:));   % highest untested connection strength

    if (best < cutoff)   % No point continuing if connection strengths have gotten really lousy
        break;
    end

    [clust1, clust2] = find(connect_strength == best);   % who're the lucky winners?
    first = min(clust1(1), clust2(1));    % if we get 2 best pairs, just take the 1st
    second = max(clust1(1), clust2(1));
    untested(first,second) = 2;  untested(second,first) = 2;   % mark that we're trying this one
    set(handle_img, 'CData', triu(untested));  title(['Trying ' num2str(first) ' and ' num2str(second)]);  drawnow;

    % Is this aggregation allowed?
    if (isfield(spikes, 'spiketimes'))   % if we were given spike times, use them ...
        t1 = spikes.spiketimes(assignments == first);
        t2 = spikes.spiketimes(assignments == second);
        if (debug)
            score = sum((diff(sort([t1; t2])) < tref)) ./ (length(t1) + length(t2));
            allow = (score < 0.05);
            scores = [0 0 score];
        else
            [allow, scores] = isiQuality(t1, t2, tmin, 0.010, tref, spikes.Fs);
        end
    else   % ... otherwise, there are no restrictions on aggregation
        allow = 1;
        scores = [0 0 0];
    end

    if (allow)   % Bookkeeping ...
        % Aggregation subsumes the higher index cluster into the lower.  Start by adding
        % (denormalized) interaction energies for the second (higher index) cluster
        % to those of the first and zeroing the old entries of the second.  Because of the
        % triangular structure of the connection matrix, repeat for both rows and columns ...
        interface_energy(first,:) = interface_energy(first,:) + interface_energy(second,:);
        interface_energy(second,:) = 0;
        interface_energy(:,first) = interface_energy(:,first) + interface_energy(:,second);
        interface_energy(:,second) = 0;
        interface_energy(second,second) = 1;   % keep self-energy at 1 (we may divide by it later)
        % since we added rows & columns, some energy values will have spilled over into the
        % lower half of the energy matrix (which must be upper triangular).  The next 2 steps
        % recover those values.
        overflow = tril(interface_energy, -1);                        % spillover below diagonal
        interface_energy = interface_energy + overflow' - overflow;   % reflect above diagonal

        % update counts vector
        numpts(first) = numpts(first) + numpts(second);
        numpts(second) = 2;   % leaving this as 2 prevents div by zero during normalization above

        % make a tree entry for the aggregation we just performed
        tree = cat(1, tree, [first, second, best, scores(3)]);

        % Now actually change the numbers
        assignments(assignments == second) = first;

        % Finally, indicate that potential aggregations between the new cluster and
        % other (nonempty) clusters are untested while pairs involving clusters that
        % have already been emptied should not be tested.
        untested(first,:) = 3;  untested(:,first) = 3;
        untested(tree(:,2),:) = 0;  untested(:,tree(:,2)) = 0;
    else
        untested(first,second) = 1;  untested(second,first) = 1;
    end
end
close(handle_fig);

%spikes.hierarchy.interface_energy_aggregated = interface_energy;
spikes.hierarchy.tree = tree;
spikes.hierarchy.assigns = assignments;

if (reintegrate_outliers && isfield(spikes, 'outliers') && ~isempty(spikes.outliers.badinds))
    % First, we make room by putting all of the non-outliers back into their original places
    spikes.waveforms(spikes.outliers.goodinds,:) = spikes.waveforms;
    spikes.spiketimes(spikes.outliers.goodinds,:) = spikes.spiketimes;
    spikes.hierarchy.assigns(spikes.outliers.goodinds) = spikes.hierarchy.assigns;

    % Then we fill in the outliers ...
    spikes.waveforms(spikes.outliers.badinds,:) = spikes.outliers.waveforms;
    spikes.spiketimes(spikes.outliers.badinds,:) = spikes.outliers.spiketimes;
    spikes.hierarchy.assigns(spikes.outliers.badinds) = 0;   % ... and add the '0' label.

    % We'll also want to add the assignments to the 'overcluster' list (this is
    % important for post-clustering splitting).
    spikes.overcluster.assigns(spikes.outliers.goodinds) = spikes.overcluster.assigns;
    spikes.overcluster.assigns(spikes.outliers.badinds) = 0;

    spikes = rmfield(spikes, 'outliers');   % don't need this any more -- it's redundant.
end

spikes.tictoc.aggregate = etime(clock, starttime);

warning(s);