function ypred = knnclass(xtest,xtrain,ytrain,k)

% KNNCLASS  K-nearest neighbors classification.
%
% Input       xtest     : ninp x ntest test inputs.
%             xtrain    : ninp x ndata training inputs.
%             ytrain    : 1 x ndata training classes (an ndata x 1 column
%                         vector is handled as well).
%             k         : number of nearest neighbors (default 1); capped
%                         at the number of training points.
% Output      ypred     : 1 x ntest class of nearest neighbor(s); for k > 1
%                         this is the column-wise result of MAJORITYVOTE
%                         over the k nearest training classes.
%
% Remarks  1) xtrain, xtest and ytrain should all be reals.
%          2) Uses unweighted Euclidean distance; neighbors are ranked on
%             the matrix returned by DISTANCE.
%          3) Raises knnclass:dimensionMismatch when xtest and xtrain are
%             both non-empty but have a different number of rows.

if nargin < 4, k = 1; end

[ninp,ntrain] = size(xtrain);

% DISTANCE returns an all-zero matrix when the row counts differ, which
% would make the first training points look "nearest" to everything.
% Fail loudly instead of returning a meaningless prediction.
if ~isempty(xtest) && ~isempty(xtrain) && size(xtest,1) ~= ninp
  error('knnclass:dimensionMismatch', ...
        'xtest has %d rows but xtrain has %d rows.', size(xtest,1), ninp);
end

k = min(ntrain,k);
    % number of neighbors cannot be larger than number of training points!!

dist = distance(xtest,xtrain);
  % matrix of size ntest-by-ntrain between ntest test inputs
  % and ntrain training inputs
if k==1
  % single neighbor: the class of the closest training point
  [dummy,index] = min(dist,[],2);
  ypred = ytrain(index);
  ypred = ypred(:)';
else
  [dummy,index] = sort(dist,2);
  % sorted such that index(i,1:k) gives the indices of the k nearest
  % neighbors of test example i
  index = index(:,1:k);     % only first k matter

  % Force the looked-up classes into the ntest-by-k shape of index before
  % transposing: plain ytrain(index) follows the orientation of ytrain
  % when index is a vector (single test point), which would hand
  % MAJORITYVOTE a 1-by-k matrix for a column ytrain.
  yclass = reshape(ytrain(index),size(index))';
  ypred = majorityvote(yclass);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function dmat = distance(a1,a2)

% DISTANCE  Compute squared Euclidean distances between the column vectors
%           of matrices a1 and a2.
%
% Input      a1   : n x m1 matrix.
%            a2   : n x m2 matrix.
% Output     dmat : m1 x m2 matrix; dmat(i,j) is the squared Euclidean
%                   distance between a1(:,i) and a2(:,j).
%
% NOTE: distance(a1) is equivalent to distance(a1,a1)
%
% NOTE: if a1 and a2 have a different number of rows, an m1 x m2 matrix of
%       zeros is returned; a warning is issued when both are non-empty.

% Tom Heskes

if nargin < 2
   % single-argument form: pairwise distances within a1
   [n1,m1] = size(a1);
   q = a1'*a1;
   dmat = diag(q)*ones(1,m1);
   dmat = dmat+dmat'-2*q;
else
   [n1,m1] = size(a1);
   [n2,m2] = size(a2);
   if (n1 == n2)
      % |x - y|^2 = |x|^2 + |y|^2 - 2*x'*y, evaluated for all pairs at once
      dmat = sum(a1.^2,1)'*ones(1,m2) + ...
              ones(m1,1)*sum(a2.^2,1) - 2*a1'*a2;
   else
      % kept for backward compatibility, but no longer silent
      if ~isempty(a1) && ~isempty(a2)
         warning('distance:dimensionMismatch', ...
                 'a1 has %d rows but a2 has %d rows; returning zeros.', ...
                 n1, n2);
      end
      dmat = zeros(m1,m2);
   end
end

% The expanded form can produce tiny negative values through round-off
% when two vectors (nearly) coincide; a squared distance is never negative.
% Logical indexing is used so that NaN entries are left untouched.
dmat(dmat < 0) = 0;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [ymaj,uniq] = majorityvote(yclass)

% MAJORITYVOTE  Column-wise majority vote over class assignments.
%
% Input    yclass : k x ndata class assignments (one column per data point,
%                   one row per voter).
% Output   ymaj   : 1 x ndata winning class of each column.
%          uniq   : 1 x ndata flag, true where exactly one class attains
%                   the highest vote count in that column.
%
% NOTE: ties are resolved in favour of the lowest class value, and the
%       corresponding entry of uniq is zero.

% Tom Heskes

% candidate classes in ascending order, as a row vector
labels = unique(yclass(:))';
nlabels = length(labels);
npoints = size(yclass,2);

% votes(c,j) = number of rows in column j that voted for labels(c)
votes = zeros(nlabels,npoints);
for c = 1:nlabels
  votes(c,:) = sum(yclass == labels(c),1);
end

% max returns the first maximal row, i.e. the lowest class on a tie
[topcount,winner] = max(votes,[],1);
ymaj = labels(winner);

if nargout > 1
  % mark every class whose count equals the column maximum
  attop = (votes - ones(nlabels,1)*topcount) == 0;
  uniq = (sum(attop,1) == 1);
end
