chronux_2_00/spikesort/utility/datatools/knn.m

knn

PURPOSE

KNN Finds the K nearest neighbors in a reference matrix.

SYNOPSIS

function [neighbors, distances] = knn(data, k)

DESCRIPTION

KNN Finds the K nearest neighbors in a reference matrix.
   NEIGHBORS = KNN(DATA, K) takes the N x D matrix DATA, where each row
   is treated as a D-dimensional point, and finds the K nearest neighbors
   of each point (using a Euclidean distance metric). NEIGHBORS is an
   N x K matrix of indices into the rows of DATA such that the jth
   nearest neighbor of the ith row of DATA is NEIGHBORS(i,j). A point is
   not counted as its own neighbor, so K must be less than N; K defaults
   to 1 if omitted.
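To make the indexing convention concrete, here is a minimal brute-force sketch in MATLAB. It does not call knn.m or chronux's pairdist, and the four 1-D points are made up for illustration. It builds the full N x N matrix of squared distances, sets the diagonal to Inf so that no point counts as its own neighbor, and sorts each row.

```matlab
% Brute-force k nearest neighbors of each row of data, excluding the point itself.
data = [0; 1; 3; 7];   % four 1-D points (N = 4, D = 1), chosen for illustration
k = 2;
N = size(data, 1);

% Squared Euclidean distances: sq(i,j) = ||data(i,:) - data(j,:)||^2
sq = zeros(N, N);
for i = 1:N
    diffs = bsxfun(@minus, data, data(i,:));   % every row minus row i
    sq(i,:) = sum(diffs.^2, 2)';
end

sq(1:N+1:end) = Inf;               % a point is not its own neighbor
[sqsorted, order] = sort(sq, 2);   % ascending distance along each row
neighbors = order(:, 1:k)          % neighbors(i,j): jth nearest neighbor of row i
dists = sqrt(sqsorted(:, 1:k))     % matching Euclidean distances
```

For these points the sketch gives neighbors = [2 3; 1 3; 2 1; 3 2] and dists = [1 3; 1 2; 2 3; 4 6]. For example, the two nearest neighbors of the point 3 (row 3) are the point 1 (row 2, distance 2) and the point 0 (row 1, distance 3).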

   [NEIGHBORS, DISTS] = KNN(DATA, K) also returns the N x K matrix DISTS
   of distances such that DISTS(i,j) is the Euclidean distance
         || DATA(i,:) - DATA(NEIGHBORS(i,j),:) ||
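The DISTS relation can be evaluated directly from a NEIGHBORS matrix: each entry is the Euclidean length of a row difference. The sketch below is a minimal illustration, assuming three made-up 2-D points whose single nearest neighbors (K = 1) were worked out by hand; it does not use knn.m itself.

```matlab
% Evaluate DISTS(i,j) = ||DATA(i,:) - DATA(NEIGHBORS(i,j),:)|| for a given NEIGHBORS.
data = [0 0; 0 1; 3 5];   % three 2-D points (illustrative)
neighbors = [2; 1; 2];    % nearest neighbor of each row (K = 1), worked out by hand
[N, k] = size(neighbors);
dists = zeros(N, k);
for i = 1:N
    for j = 1:k
        d = data(i,:) - data(neighbors(i,j),:);   % difference vector
        dists(i,j) = sqrt(sum(d.^2));             % its Euclidean length
    end
end
disp(dists)
```

Rows 1 and 2 are each other's nearest neighbors at distance 1, and row 3 is closest to row 2 at distance 5 (its distance to row 1 is sqrt(34), about 5.83).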

CROSS-REFERENCE INFORMATION

This function calls:
This function is called by:

SOURCE CODE

function [neighbors, distances] = knn(data, k)
%KNN               Finds the K nearest neighbors in a reference matrix.
%   NEIGHBORS = KNN(DATA, K) takes the N x D matrix DATA, where each row
%   is treated as a D-dimensional point, and finds the K nearest neighbors
%   of each point (using a Euclidean distance metric).  NEIGHBORS is an
%   N x K matrix of indices into the rows of the DATA such that the jth
%   nearest neighbor of the ith row of DATA is NEIGHBORS(i,j).  A point is
%   not counted as its own neighbor, so K must be less than N (default 1).
%
%   [NEIGHBORS, DISTS] = KNN(DATA,K) also returns the matrix DISTS of
%   distances such that DISTS(i,j) is the distance:
%         || DATA(i,:) - DATA(NEIGHBORS(i,j),:) ||

% Currently uses a naive O(N^2) algorithm (although at least it's broken
% up into chunks so that it only needs O(N) memory).

N = size(data, 1);
default_chunk = 500;

if (nargin < 2)
    k = 1;
end
if (k >= N)
    error('knn:badK', 'K must be less than the number of rows in DATA.');
end

% Go a chunk at a time so we're not holding the entire N^2 matrix in memory at once.
chunksize = min(N, default_chunk);

neighbors = zeros(N, k);   % neighbor indices
distances = zeros(N, k);   % squared distances until the final sqrt below

for chunkstart = 1:chunksize:N
    chunkfinis = min(chunkstart+chunksize-1, N);
    rows = chunkstart:chunkfinis;
    disp(['Examining waveforms ' num2str(chunkstart) ' to ' num2str(chunkfinis) '.']);

    % Squared Euclidean distances from each row in this chunk to every row of DATA.
    dists = pairdist(data(rows,:), data, 'nosqrt', 'reuse');

    % Exclude each point from its own neighbor list.  Setting the self-distance
    % to Inf is safer than dropping the first sorted column, which can hold a
    % different point when DATA contains duplicate rows.
    dists(sub2ind(size(dists), 1:numel(rows), rows)) = Inf;
    [dists, inds] = sort(dists, 2);

    neighbors(rows, :) = inds(:, 1:k);
    distances(rows, :) = dists(:, 1:k);
end

% Take the square root only when DISTS is requested (nargout is at most 2);
% max(..., 0) guards against tiny negative values from rounding.
if (nargout > 1), distances = sqrt(max(distances, 0)); end
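The chunking strategy in the source can be sketched without chronux. This is an illustrative stand-in, not the library's code: pairdist is replaced by the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2ab' (an assumption about how such a routine can be vectorized), the chunk size is shrunk to 4 so several chunks run over ten made-up points, and the chunked result is compared with a one-shot computation on the full matrix.

```matlab
% Chunked brute-force k-nearest-neighbor search (illustrative sketch).
% Assumption: squared distances use ||a - b||^2 = ||a||^2 + ||b||^2 - 2*a*b'
% in place of chronux's pairdist.
x = (1:10)';
data = [x, mod(7*x, 11)];   % ten distinct integer-valued 2-D points (made up)
k = 3;
chunksize = 4;              % deliberately small so the loop visits several chunks
N = size(data, 1);
sqnorm = sum(data.^2, 2);   % N x 1 squared row norms, computed once

neighbors = zeros(N, k);
distances = zeros(N, k);
for chunkstart = 1:chunksize:N
    rows = chunkstart:min(chunkstart + chunksize - 1, N);
    % Only a numel(rows) x N block of squared distances exists at a time.
    sq = bsxfun(@plus, sqnorm(rows), sqnorm') - 2 * (data(rows,:) * data');
    sq(sub2ind(size(sq), 1:numel(rows), rows)) = Inf;   % exclude each point itself
    [sq, order] = sort(sq, 2);                          % ascending along each row
    neighbors(rows,:) = order(:, 1:k);
    distances(rows,:) = sqrt(max(sq(:, 1:k), 0));       % clamp rounding below zero
end

% Reference result: the same search on the full N x N matrix in one step.
sqall = bsxfun(@plus, sqnorm, sqnorm') - 2 * (data * data');
sqall(1:N+1:end) = Inf;
[allsorted, allorder] = sort(sqall, 2);
same_neighbors = isequal(neighbors, allorder(:, 1:k))
```

Only a chunksize x N block of distances exists at any time, so memory grows linearly in N for a fixed chunk size while the arithmetic stays O(N^2 D). The expansion form vectorizes well but can lose precision for points far from the origin, which is why the sketch clamps at zero before taking the square root.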

Generated on Fri 15-Aug-2008 11:35:42 by m2html © 2003