


KNN Finds the K nearest neighbors in a reference matrix.
NEIGHBORS = KNN(DATA, K) takes the N x D matrix DATA, where each row
is treated as a D-dimensional point, and finds the K nearest neighbors
of each point (using a Euclidean distance metric). NEIGHBORS is an
N x K matrix of indices into the rows of DATA such that the jth
nearest neighbor of the ith row of DATA is NEIGHBORS(i,j). If K is
omitted, it defaults to 1.
[NEIGHBORS, DISTS] = KNN(DATA,K) also returns the matrix DISTS of
distances such that DISTS(i,j) is the distance:
|| DATA(i,:) - DATA(NEIGHBORS(i,j),:) ||
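
The NEIGHBORS/DISTS contract described above can be illustrated with a small brute-force sketch. This is not the function's own implementation: it is a plain double loop over a made-up four-point data set, and the variable names (data, neighbors, dists) are chosen here only to mirror the help text. It assumes a point is never counted as its own neighbor.

```matlab
% Illustrative data (an assumption, not from the original page): 4 points in 2-D,
% chosen so that no two distances from the same point are tied.
data = [0 0; 3 4; 10 0; 0 1];
k = 2;
N = size(data, 1);

neighbors = zeros(N, k);   % neighbors(i,j): row index of the jth nearest neighbor of row i
dists     = zeros(N, k);   % dists(i,j): || data(i,:) - data(neighbors(i,j),:) ||

for i = 1:N
    d = zeros(1, N);
    for j = 1:N
        d(j) = norm(data(i,:) - data(j,:));   % Euclidean distance
    end
    d(i) = Inf;                               % exclude the point itself
    [sorted, order] = sort(d);
    neighbors(i,:) = order(1:k);
    dists(i,:)     = sorted(1:k);
end

disp(neighbors);
disp(dists);
```

For this data, row 1 of neighbors is [4 2], because the point (0,1) lies at distance 1 from the origin and (3,4) at distance 5.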

function [neighbors, distances] = knn(data, k)
%KNN Finds the K nearest neighbors in a reference matrix.
%   NEIGHBORS = KNN(DATA, K) takes the N x D matrix DATA, where each row
%   is treated as a D-dimensional point, and finds the K nearest neighbors
%   of each point (using a Euclidean distance metric).  NEIGHBORS is an
%   N x K matrix of indices into the rows of DATA such that the jth
%   nearest neighbor of the ith row of DATA is NEIGHBORS(i,j).
%
%   [NEIGHBORS, DISTS] = KNN(DATA,K) also returns the matrix DISTS of
%   distances such that DISTS(i,j) is the distance:
%         || DATA(i,:) - DATA(NEIGHBORS(i,j),:) ||

% Currently uses a naive O(N^2) algorithm (although at least
% it's broken up so that it needs only O(N) memory).

[N,D] = size(data);
default_chunk = 500;

% K defaults to 1 when omitted
if (nargin < 2)
    k = 1;
end

% go a chunk at a time so we're not holding the entire N^2 matrix in memory at once
if (N > default_chunk)
    chunksize = default_chunk;
else
    chunksize = N;
end

neighbors = zeros(N, k);   % neighbor indices
distances = zeros(N, k);   % squared distances until the final sqrt below

for chunkstart = 1:chunksize:N
    chunkfinis = min(chunkstart+chunksize-1, N);
    disp(['Examining waveforms ' num2str(chunkstart) ' to ' num2str(chunkfinis) '.']);

    % pairdist is an external helper; with 'nosqrt' it is taken here to
    % return squared distances from each chunk row to every row of data
    dists = pairdist(data(chunkstart:chunkfinis,:), data, 'nosqrt', 'reuse');
    [dists, inds] = sort(dists, 2);

    % column 1 is the point itself (distance 0), so skip it
    neighbors(chunkstart:chunkfinis, :) = inds(:, 2:(k+1));
    distances(chunkstart:chunkfinis, :) = dists(:, 2:(k+1));
end

% Take the square root only when DISTS is requested.  The original test
% was (nargout > 2), which can never be true for a two-output function,
% so squared distances were returned instead of the documented norm.
if (nargout > 1), distances = sqrt(distances); end
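
The chunking idea in the listing can be tried in isolation with the following sketch. Because pairdist is not shown on this page, the anonymous function sqdist below is a hypothetical stand-in that computes squared Euclidean distances directly; the data and the deliberately tiny chunksize are also assumptions made for illustration.

```matlab
% Hypothetical stand-in for pairdist(..., 'nosqrt'): squared Euclidean
% distances between the rows of A (M x D) and the rows of B (N x D).
% max(..., 0) clamps small negative values caused by round-off.
sqdist = @(A, B) max(bsxfun(@plus, sum(A.^2, 2), sum(B.^2, 2)') - 2 * (A * B'), 0);

data = [0 0; 1 0; 3 0; 7 0; 15 0];   % illustrative points with no distance ties
k = 2;
chunksize = 2;                        % tiny on purpose, to force several chunks
N = size(data, 1);

neighbors = zeros(N, k);
distances = zeros(N, k);

for chunkstart = 1:chunksize:N
    chunkfinis = min(chunkstart + chunksize - 1, N);
    rows = chunkstart:chunkfinis;

    d2 = sqdist(data(rows, :), data);         % at most chunksize x N entries at a time
    [d2, inds] = sort(d2, 2);

    neighbors(rows, :) = inds(:, 2:(k+1));    % column 1 is the point itself
    distances(rows, :) = sqrt(d2(:, 2:(k+1)));
end

disp(neighbors);
disp(distances);
```

Only a chunksize x N block of distances exists at any moment, which is what keeps the memory use linear in N for a fixed chunk size. The total work is still proportional to N^2.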