


KNN Finds the K nearest neighbors in a reference matrix.
NEIGHBORS = KNN(DATA, K) takes the N x D matrix DATA, where each row
is treated as a D-dimensional point, and finds the K nearest neighbors
of each point (using a Euclidean distance metric), not counting the point
itself. K defaults to 1. NEIGHBORS is an N x K matrix of indices into the
rows of DATA such that the jth nearest neighbor of the ith row of DATA is
NEIGHBORS(i,j).

[NEIGHBORS, DISTS] = KNN(DATA, K) also returns the N x K matrix DISTS of
distances such that DISTS(i,j) is the distance:
    || DATA(i,:) - DATA(NEIGHBORS(i,j),:) ||
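As a small worked example of what NEIGHBORS and DISTS contain, the sketch below computes both by brute force for four hypothetical 1-D points (0, 1, 3 and 7) with K = 2. It does not call knn itself (which depends on a pairdist helper not shown here). It builds the full N x N matrix of squared distances, which is only sensible for small N, and it assumes DATA has no duplicate rows, so that column 1 of each sorted row is the point itself.

```matlab
% Hypothetical example data: four points on a line, each row is one point.
data = [0; 1; 3; 7];
k = 2;

% Full N x N matrix of squared Euclidean distances via
% ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y (fine for small N only).
sq = sum(data.^2, 2);
D2 = bsxfun(@plus, sq, sq.') - 2 * (data * data.');
D2 = max(D2, 0);                 % clamp any tiny negative round-off

% Sort each row; column 1 is the point itself (distance 0), so skip it.
[sorted, order] = sort(D2, 2);
neighbors = order(:, 2:k+1);     % e.g. row 1 (the point 0) -> points 1 and 3
distances = sqrt(sorted(:, 2:k+1));
```

For the point 3 (row 3), the nearest other point is 1 (row 2, distance 2) and the next is 0 (row 1, distance 3), so that row of NEIGHBORS is [2 1] and that row of DISTS is [2 3].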

function [neighbors, distances] = knn(data, k)
%KNN Finds the K nearest neighbors in a reference matrix.
%   NEIGHBORS = KNN(DATA, K) takes the N x D matrix DATA, where each row
%   is treated as a D-dimensional point, and finds the K nearest neighbors
%   of each point (using a Euclidean distance metric). NEIGHBORS is an
%   N x K matrix of indices into the rows of DATA such that the jth
%   nearest neighbor of the ith row of DATA is NEIGHBORS(i,j).
%
%   [NEIGHBORS, DISTS] = KNN(DATA,K) also returns the matrix DISTS of
%   distances such that DISTS(i,j) is the distance:
%           || DATA(i,:) - DATA(NEIGHBORS(i,j),:) ||


% Currently uses a naive O(N^2) algorithm (although at least
% it's broken up to need only O(N) memory).

[N,D] = size(data);
default_chunk = 500;

if (nargin < 2)
    k = 1;
end

% Go a chunk at a time so we're not holding the entire N^2 matrix in memory at once.
if (N > default_chunk)
    chunksize = default_chunk;
else
    chunksize = N;
end

neighbors = zeros(N, k);   % neighbor indices
distances = zeros(N, k);   % squared distances until the optional sqrt below

for chunkstart = 1:chunksize:N
    chunkfinis = min(chunkstart+chunksize-1, N);
    disp(['Examining waveforms ' num2str(chunkstart) ' to ' num2str(chunkfinis) '.']);

    % Squared distances from this chunk of rows to every row of DATA.
    dists = pairdist(data(chunkstart:chunkfinis,:), data, 'nosqrt', 'reuse');
    [dists, inds] = sort(dists, 2);

    % Column 1 is normally the point itself (distance 0), so skip it.
    % If DATA contains duplicate rows, a duplicate can sort ahead of the point itself.
    neighbors(chunkstart:chunkfinis, :) = inds(:, 2:(k+1));
    distances(chunkstart:chunkfinis, :) = dists(:, 2:(k+1));
end

% Convert squared distances to Euclidean distances only if DISTS was requested.
if (nargout > 1), distances = sqrt(distances); end
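The listing calls pairdist, which is not defined here, and the meaning of its 'reuse' option is not documented in this file. If that helper is unavailable, the same chunked O(N^2)-time, O(N)-memory loop can be sketched with base MATLAB only. This is a substitute, not the toolbox's pairdist: it assumes 'nosqrt' means "return squared Euclidean distances" and computes them with the expansion ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y. The data set below is hypothetical. One design choice differs from knn: instead of assuming column 1 of each sorted row is the point itself, the sketch sets each point's self-distance to Inf before sorting, so duplicate rows cannot make a point its own neighbor.

```matlab
% Hypothetical deterministic test data: N = 1200 points in D = 3 dimensions.
N = 1200;
t = (1:N).';
data = [sin(t), cos(2*t), mod(t, 7) / 7];
k = 3;
chunksize = 500;                  % same default chunk size as knn

sqnorm = sum(data.^2, 2);         % ||x_j||^2 for every point, computed once
neighbors = zeros(N, k);
distances = zeros(N, k);

for chunkstart = 1:chunksize:N
    rows = chunkstart:min(chunkstart + chunksize - 1, N);

    % Squared distances from this chunk (numel(rows) x N) to all points.
    D2 = bsxfun(@plus, sqnorm(rows), sqnorm.') - 2 * (data(rows, :) * data.');
    D2 = max(D2, 0);              % clamp round-off below zero

    % Exclude each point from its own neighbor list explicitly.
    D2(sub2ind(size(D2), 1:numel(rows), rows)) = Inf;

    [D2, order] = sort(D2, 2);
    neighbors(rows, :) = order(:, 1:k);
    distances(rows, :) = sqrt(D2(:, 1:k));   % Euclidean distances
end
```

The expansion identity is fast because the cross term is a single matrix product, but it can lose precision when two points are nearly coincident relative to their norms; computing differences directly avoids that at the cost of speed.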