knn

PURPOSE

KNN Finds the K nearest neighbors in a reference matrix.

SYNOPSIS

function [neighbors, distances] = knn(data, k)

DESCRIPTION

KNN               Finds the K nearest neighbors in a reference matrix.
   NEIGHBORS = KNN(DATA, K) takes the N x D matrix DATA, where each row
   is treated as a D-dimensional point, and finds the K nearest neighbors
   of each point among the other rows (using a Euclidean distance
   metric).  K defaults to 1 if omitted.  NEIGHBORS is an N x K matrix of
   indices into the rows of DATA such that the jth nearest neighbor of
   the ith row of DATA is NEIGHBORS(i,j).

   [NEIGHBORS, DISTS] = KNN(DATA,K) also returns the matrix DISTS of
   distances such that DISTS(i,j) is the distance:
         || DATA(i,:) - DATA(NEIGHBORS(i,j),:) ||
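
As a rough illustration of what NEIGHBORS and DISTS should contain, the sketch below computes them by brute force for four points in the plane. It does not call KNN or any helper it depends on. The variable names (sqd, ref_neighbors, ref_dists) are made up for this example, and masking the diagonal with Inf is a choice made here to exclude each point from its own neighbor list, not necessarily how KNN does it.

```matlab
% Four points in the plane, one per row (N = 4, D = 2).
data = [0 0; 1 0; 0 2; 5 5];
k = 2;
N = size(data, 1);

% Brute-force reference: squared Euclidean distance between every pair of rows.
sqd = zeros(N, N);
for i = 1:N
    diffs = data - repmat(data(i,:), N, 1);   % offsets from row i to every row
    sqd(i,:) = sum(diffs.^2, 2)';
end
sqd(1:N+1:end) = Inf;     % mask the diagonal so no point is its own neighbor

% Sort each row by distance and keep the k closest.
[sorted, order] = sort(sqd, 2);
ref_neighbors = order(:, 1:k);        % N x k indices into the rows of data
ref_dists = sqrt(sorted(:, 1:k));     % N x k Euclidean distances
```

Here ref_neighbors(1,:) is [2 3]: the point (1,0) is nearest to (0,0) at distance 1, followed by (0,2) at distance 2.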

CROSS-REFERENCE INFORMATION

This function calls:
This function is called by:

SOURCE CODE

function [neighbors, distances] = knn(data, k)
%KNN               Finds the K nearest neighbors in a reference matrix.
%   NEIGHBORS = KNN(DATA, K) takes the N x D matrix DATA, where each row
%   is treated as a D-dimensional point, and finds the K nearest neighbors
%   of each point among the other rows (using a Euclidean distance
%   metric).  K defaults to 1 if omitted.  NEIGHBORS is an N x K matrix of
%   indices into the rows of DATA such that the jth nearest neighbor of
%   the ith row of DATA is NEIGHBORS(i,j).
%
%   [NEIGHBORS, DISTS] = KNN(DATA,K) also returns the matrix DISTS of
%   distances such that DISTS(i,j) is the distance:
%         || DATA(i,:) - DATA(NEIGHBORS(i,j),:) ||


% Currently uses a naive O(N^2) algorithm (although at least
% it's broken up to use only O(N) memory).

[N,D] = size(data);
default_chunk = 500;

if (nargin < 2)
    k = 1;
end
% each point needs k neighbors besides itself, so k can be at most N-1
if (k > N-1)
    error('K must be smaller than the number of rows in DATA.');
end
% go a chunk at a time so we're not holding the entire N^2 matrix in memory at once
if (N > default_chunk)
    chunksize = default_chunk;
else
    chunksize = N;
end

neighbors = zeros(N, k); %  neighbor indices
distances = zeros(N, k); %  squared distances until the final sqrt below

for chunkstart = 1:chunksize:N
    chunkfinis = min(chunkstart+chunksize-1, N);
    disp(['Examining waveforms ' num2str(chunkstart) ' to ' num2str(chunkfinis) '.']);

    % squared distances from the rows of this chunk to every row of data
    dists = pairdist(data(chunkstart:chunkfinis,:), data, 'nosqrt', 'reuse');
    [dists, inds] = sort(dists, 2);

    % column 1 is normally the point itself (distance 0), so skip it
    neighbors(chunkstart:chunkfinis, :) = inds(:, 2:(k+1));
    distances(chunkstart:chunkfinis, :) = dists(:, 2:(k+1));
end

% The original test was (nargout > 2), which can never be true for a function
% with two outputs, so DISTS came back squared.  Take the root whenever the
% second output is requested.
if (nargout > 1),  distances = sqrt(distances);  end
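
The chunking in the listing can be pictured with a small sketch. It reproduces only the index arithmetic of the loop, for a hypothetical N of 1200 points, and needs no distance helper. The names starts, finis, bounds and peak_entries are illustrative and do not appear in the listing.

```matlab
N = 1200;                  % hypothetical number of points
chunksize = min(N, 500);   % same cap as default_chunk in the listing

starts = 1:chunksize:N;                    % first row of each chunk
finis  = min(starts + chunksize - 1, N);   % last row of each chunk
bounds = [starts' finis'];                 % one [first last] pair per chunk

% Largest distance block held at any one time: chunksize rows by N columns.
peak_entries = chunksize * N;
full_entries = N * N;      % size of the full pairwise matrix, for comparison
```

For these values bounds is [1 500; 501 1000; 1001 1200], and the largest block holds 600000 entries instead of the 1440000 a full N x N matrix would need. With the chunk size capped at 500, peak memory grows linearly in N while the total work stays O(N^2).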

Generated on Sun 13-Aug-2006 11:49:44 by m2html © 2003