Nearest neighbor (NN) search is used to find objects that are similar to each other: given an input, it finds the objects in our database that are most similar to that input. As a simple example, if you had a database of news articles and wanted to retrieve news similar to your query, you would run a nearest neighbor search for your query against all articles in the database and return the top 10 results.

In NN search, the distance function is important: it decides how similar or dissimilar two objects are. Lower distance values indicate that the objects are similar, whereas higher values indicate that they are dissimilar. For example, a Euclidean distance of 0 means two objects are identical, while a cosine distance of 0 only means that their vectors point in the same direction.
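To make the difference between distance functions concrete, here is a small sketch on made-up 3-element vectors using SciPy's `euclidean` and `cosine` functions (the vectors are illustrative only). It shows that two identical vectors are at distance 0 under both metrics, but a vector scaled by 2 is at cosine distance 0 while still having a positive Euclidean distance.

```python
import numpy as np
from scipy.spatial.distance import cosine, euclidean

a = np.array([1.0, 2.0, 0.0])
b = np.array([1.0, 2.0, 0.0])  # identical to a
c = np.array([2.0, 4.0, 0.0])  # same direction as a, twice as long
d = np.array([0.0, 0.0, 3.0])  # orthogonal to a

# Identical vectors: both distances are 0
print(euclidean(a, b), cosine(a, b))
# Same direction: cosine distance is (numerically) 0, Euclidean is sqrt(5)
print(euclidean(a, c), cosine(a, c))
# Orthogonal vectors: cosine distance is 1, Euclidean is sqrt(14)
print(euclidean(a, d), cosine(a, d))
```

This is why cosine distance is popular for text: it ignores vector length, so a long and a short document about the same topic can still be close.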

Data processing

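To keep things simple, we’ll use a dataset provided by the scikit-learn library called 20 newsgroups. The 20 newsgroups dataset comprises around 18,000 newsgroup posts on 20 topics. By default, fetch_20newsgroups loads only the training subset (11,314 posts), which is what we use below. The code below will load the data.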

from sklearn.datasets import fetch_20newsgroups

bunch = fetch_20newsgroups(remove='headers')

print(type(bunch), bunch.keys())
# (sklearn.utils.Bunch, dict_keys(['data', 'filenames', 'target_names', 'target', 'DESCR']))

The output is basically a dict-like object with the keys shown above. For this demo we are only interested in the value under the data key, which contains a list of posts. As an example:

' I was wondering if anyone out there could enlighten me on this car I saw\nthe other day. It was a 2-door sports car, looked to be from the late 60s/\nearly 70s. It was called a Bricklin. The doors were really small. In addition,\nthe front bumper was separate from the rest of the body. This is \nall I know. If anyone can tellme a model name, engine specs, years\nof production, where this car is made, history, or whatever info you\nhave on this funky looking car, please e-mail.\n\nThanks,\n- IL\n   ---- brought to you by your neighborhood Lerxst ----\n\n\n\n\n'
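The dict-like object is an instance of scikit-learn's `Bunch` class, which supports both key access and attribute access. The sketch below builds a tiny made-up `Bunch` by hand (hypothetical posts, nothing is downloaded) just to show that `bunch["data"]` and `bunch.data` refer to the same list.

```python
from sklearn.utils import Bunch

# A tiny stand-in for the object returned by fetch_20newsgroups;
# the posts and labels here are made up for illustration.
toy = Bunch(
    data=["first made-up post", "second made-up post"],
    target=[0, 1],
    target_names=["rec.autos", "comp.graphics"],
)

print(toy.keys())               # dict_keys(['data', 'target', 'target_names'])
print(toy["data"] is toy.data)  # True: key and attribute access return the same list
print(len(toy.data))            # 2
```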

Feature extraction

As we know, we need to represent our entities in a vector space in order to use nearest neighbors. Since our entities are texts, we need a feature extraction technique to turn them into feature vectors. For this demo, we’ll use TF-IDF. If you are not familiar with feature extraction for text, refer to this article.

To keep things simple, we won’t tune any parameters and will use the defaults for everything except max_features, which limits the vocabulary to the 10,000 most frequent terms.

from sklearn.feature_extraction.text import TfidfVectorizer
vec = TfidfVectorizer(max_features=10_000)
# learn the vocabulary from the posts and turn each post into a TF-IDF vector
features = vec.fit_transform(bunch.data)
print(features.shape) # (11314, 10000)

Now we have feature vectors for the entire dataset, where each vector has 10,000 elements (one per vocabulary term). We can finally train a nearest neighbors model.

Model training

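This step is also very simple. You just need to instantiate a NearestNeighbors object and train it by calling its fit method. For NearestNeighbors, “training” simply means storing the feature vectors (and, for the tree-based algorithms, building an index over them).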

Important Parameters

Here are some important parameters that you might need to change depending on your needs.

n_neighbors : Number of neighbors to use by default for kneighbors queries.

algorithm : Algorithm used to compute the nearest neighbors:

  • ‘ball_tree’ will use BallTree
  • ‘kd_tree’ will use KDTree
  • ‘brute’ will use a brute-force search.
  • ‘auto’ will attempt to decide the most appropriate algorithm based on the values passed to the fit method (sparse input, such as our TF-IDF matrix, always uses brute force).

metric : metric to use for distance computation. Any metric from scikit-learn or scipy.spatial.distance can be used. If metric is a callable function, it is called on each pair of instances (rows) and the resulting value recorded. The callable should take two arrays as input and return one value indicating the distance between them. This works for Scipy’s metrics, but is less efficient than passing the metric name as a string.

Valid values for metric include:

  • from scikit-learn: [‘cityblock’, ‘cosine’, ‘euclidean’, ‘l1’, ‘l2’, ‘manhattan’]
  • from scipy.spatial.distance: its metric names, such as ‘braycurtis’, ‘chebyshev’ or ‘correlation’
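The algorithm and metric parameters can be explored on a small, randomly generated dense data set (made up for illustration; our TF-IDF features are sparse, so scikit-learn would use brute force on them anyway). This sketch checks that the three exact algorithms return the same neighbors, that a hand-written callable metric (the hypothetical manhattan function below) matches the built-in ‘manhattan’ string, and that a tree algorithm rejects a metric it does not support, such as cosine.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

rng = np.random.default_rng(0)
X_demo = rng.random((200, 5))  # small dense toy data set, made up for illustration
queries = X_demo[:3]

# Same metric, three exact search algorithms: the neighbors should agree.
results = {}
for algo in ["brute", "ball_tree", "kd_tree"]:
    nn = NearestNeighbors(n_neighbors=4, algorithm=algo, metric="euclidean").fit(X_demo)
    results[algo] = nn.kneighbors(queries, return_distance=False)
print(all(np.array_equal(results["brute"], r) for r in results.values()))

# A callable metric gives the same neighbors as the built-in name,
# but it is evaluated in Python for every pair, so it is much slower.
def manhattan(u, v):
    return np.abs(u - v).sum()

idx_name = NearestNeighbors(n_neighbors=4, metric="manhattan").fit(X_demo).kneighbors(
    queries, return_distance=False)
idx_func = NearestNeighbors(n_neighbors=4, metric=manhattan).fit(X_demo).kneighbors(
    queries, return_distance=False)
print(np.array_equal(idx_name, idx_func))

# Tree-based algorithms only support some metrics; cosine needs brute force.
try:
    NearestNeighbors(algorithm="ball_tree", metric="cosine").fit(X_demo)
except ValueError as err:
    print("rejected:", err)
```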

from sklearn.neighbors import NearestNeighbors
knn = NearestNeighbors(n_neighbors=10, metric='cosine')
knn.fit(features)  # store the TF-IDF vectors so they can be searched

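That is all it takes to train a KNN model. I’ve used cosine as the metric because it is commonly used for text similarity. Euclidean is also a good choice here: TfidfVectorizer L2-normalizes each row by default, and for unit-length vectors Euclidean distance is a monotonic function of cosine distance, so both metrics return the same neighbors.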
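The cosine/Euclidean equivalence can be checked numerically. For unit vectors, the squared Euclidean distance equals 2 × (1 − cos θ), i.e. twice the cosine distance, so Euclidean distance is sqrt(2 × cosine distance). The sketch below uses a made-up five-document corpus (not the newsgroups data) and compares the sorted neighbor distances from both metrics.

```python
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors

docs = [
    "i need to clean my car",
    "best way to wash a car",
    "car engine will not start",
    "good ftp sites for mac software",
    "looking for ftp sites with drivers",
]
X_docs = TfidfVectorizer().fit_transform(docs)  # rows are L2-normalized by default

d_cos, _ = NearestNeighbors(n_neighbors=3, metric="cosine").fit(X_docs).kneighbors(X_docs)
d_euc, _ = NearestNeighbors(n_neighbors=3, metric="euclidean").fit(X_docs).kneighbors(X_docs)

# For unit vectors ||a - b||^2 = 2 * (1 - cos(a, b)) = 2 * cosine_distance,
# so Euclidean distance = sqrt(2 * cosine distance), a monotonic transformation.
print(np.allclose(d_euc, np.sqrt(2 * d_cos), atol=1e-6))  # True
```

Because the transformation is monotonic, the k smallest distances come from the same documents; only the numeric values differ. Only the distances are compared because exact ties could be ordered differently by the two metrics.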


Now that we have a KNN model, how do we find similar items for a given input text? We first convert the text into a feature vector with vec.transform (not fit_transform, which would learn a new vocabulary) and then pass this vector to the knn model.
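The whole query pipeline can be sketched end to end on a tiny made-up corpus (the corpus, the query and the names vec_toy and knn_toy are hypothetical, chosen so the example runs without downloading anything). The key point is that the vectorizer is fitted once on the corpus and the query is only transformed, so it is expressed in the same vocabulary as the stored vectors.

```python
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neighbors import NearestNeighbors

corpus = [
    "any good ftp sites for mac shareware",
    "my car will not start in the cold",
    "how do i wash and wax my car",
    "looking for compilers on ftp sites",
]

vec_toy = TfidfVectorizer()
corpus_features = vec_toy.fit_transform(corpus)  # fit the vocabulary on the corpus only
knn_toy = NearestNeighbors(n_neighbors=2, metric="cosine").fit(corpus_features)

query = ["where can i wash my car"]
query_features = vec_toy.transform(query)        # transform, NOT fit_transform
indices = knn_toy.kneighbors(query_features, return_distance=False)

for idx in indices[0]:
    print(idx, corpus[idx])
# 2 how do i wash and wax my car
# 1 my car will not start in the cold
```

Query words that are not in the fitted vocabulary (here “where” and “can”) are simply ignored by transform.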

Let’s quickly look at the output of knn. To find nearest neighbors, we call the kneighbors method. Its first parameter is a 2D array (or sparse matrix) of query feature vectors, one per row. If return_distance is False, it returns only a 2D array in which each row contains the indices of the k nearest neighbors of the corresponding query vector.

If return_distance is True (the default), it returns a tuple of two 2D arrays: in the first, each row contains the distances to the k nearest neighbors, sorted from closest to farthest, and in the second, each row contains the corresponding neighbor indices.
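The two return formats, plus the per-call n_neighbors override, can be seen on four made-up 1D points, where the distances are easy to check by hand (the query 0.9 is 0.1 away from 1, 0.9 away from 0 and 2.1 away from 3).

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.array([[0.0], [1.0], [3.0], [7.0]])   # four 1-D points, made up for illustration
nn = NearestNeighbors(n_neighbors=3).fit(X)  # default metric: Euclidean (minkowski, p=2)

idx = nn.kneighbors([[0.9]], return_distance=False)
print(idx)   # [[1 0 2]]

dist, idx2 = nn.kneighbors([[0.9]], return_distance=True)
print(dist)  # approximately [[0.1 0.9 2.1]], sorted ascending
print(idx2)  # [[1 0 2]]

# n_neighbors passed to kneighbors overrides the value given to the constructor
idx_k2 = nn.kneighbors([[0.9]], n_neighbors=2, return_distance=False)
print(idx_k2)  # [[1 0]]
```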

knn.kneighbors(features[0:1], return_distance=False)
# array([[   0,  958, 8013, 8266,  659, 5553, 3819, 2554, 6055, 7993]])

knn.kneighbors(features[0:1], return_distance=True)
# (array([[0.        , 0.35119023, 0.62822688, 0.64738668, 0.66613124,
#         0.67267273, 0.68149664, 0.68833514, 0.70024449, 0.70169709]]),
# array([[   0,  958, 8013, 8266,  659, 5553, 3819, 2554, 6055, 7993]]))

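Note that the returned indices refer to rows of the feature matrix used to train the model, so they can be used directly to look up the original posts in bunch.data. The first neighbor of features[0] is post 0 itself, at distance 0. Also, many scikit-learn methods, such as fit and kneighbors, expect 2D input, so we passed features[0:1] rather than features[0]: slicing keeps a single row as a 2D matrix, whereas indexing a dense NumPy array with features[0] would give a 1D array.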
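The 2D requirement is easiest to see with a small dense NumPy array (made up for illustration): plain indexing drops a dimension and kneighbors rejects it with a ValueError, while slicing or reshape(1, -1) produces a one-row 2D query.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

X = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
nn = NearestNeighbors(n_neighbors=2).fit(X)

print(X[0].shape)    # (2,)   1-D: a single vector
print(X[0:1].shape)  # (1, 2) 2-D: one query row

try:
    nn.kneighbors(X[0])  # 1-D input is rejected
except ValueError as err:
    print("rejected:", err)

idx = nn.kneighbors(X[0:1], return_distance=False)
idx_reshaped = nn.kneighbors(X[0].reshape(1, -1), return_distance=False)  # equivalent
print(idx)  # [[0 1]]: the closest point to a training point is the point itself
```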

Finally, the code below shows how to take raw input texts, extract their features and find their nearest neighbors. Note that here we explicitly pass n_neighbors=2, so we get 2 neighbors per query. If it is not passed, kneighbors returns 10 neighbors, because we set n_neighbors=10 when instantiating the model above.

input_texts = ["any recommendations for good ftp sites?", "i need to clean my car"]
# transform (not fit_transform) so the queries use the vocabulary learned above
input_features = vec.transform(input_texts)

D, N = knn.kneighbors(input_features, n_neighbors=2, return_distance=True)

for input_text, distances, neighbors in zip(input_texts, D, N):
    print("Input text = ", input_text, "\n")
    for dist, neighbor_idx in zip(distances, neighbors):
        print("Distance = ", dist, "Neighbor idx = ", neighbor_idx)
        # show the first 200 characters of the matching post
        print(bunch.data[neighbor_idx][:200])
Input text =  any recommendations for good ftp sites? 

Distance =  0.5870334253639387 Neighbor idx =  89
I would like to experiment with the INTEL 8051 family.  Does anyone out  
there know of any good FTP sites that might have compiliers, assemblers,  

Distance =  0.6566334116701875 Neighbor idx =  7665

I am looking for ftp sites (where there are freewares or sharewares)
for Mac. It will help a lot if there are driver source codes in those 
ftp sites. Any information is appreciated. 

Thanks in 

Input text =  i need to clean my car 

Distance =  0.6592186982514803 Neighbor idx =  8013
In article <[email protected]> [email protected] (Rhonda Gaines) writes:
>I'm planning on purchasing a new car and will be trading in my '90
>Mazda MX-6 DX.  I've still got 2 more years to pay o
Distance =  0.692693967282819 Neighbor idx =  7993
I bought a car with a defunct engine, to use for parts
for my old but still running version of the same car.

The car I bought has good tires.

Is there anything in particular that I should do to