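I am trying to write an algorithm that can find user-specified nearest neighbors. By user-specified, I mean that the user can choose whether to search for a general nearest neighbor (the closest value in either direction), a forward-nearest neighbor (the closest value greater than or equal to the search value), or a backward-nearest neighbor (the closest value less than or equal to the search value).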
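The three modes can be pinned down with a small pure-Python reference sketch. The helper name `nearest_indices_reference` is hypothetical and not part of the code below. It assumes "forward" means the smallest value greater than or equal to the search value and "backward" the largest value less than or equal to it, and it returns every index that ties for nearest. It makes no attempt at speed.

```python
def nearest_indices_reference(data, target, direction=None):
    """Hypothetical reference helper: indices of every value tied for nearest."""
    if direction is None:
        # distance in either direction
        candidates = [(abs(x - target), i) for i, x in enumerate(data)]
    elif direction == "forward":
        # only values >= target qualify
        candidates = [(x - target, i) for i, x in enumerate(data) if x >= target]
    elif direction == "backward":
        # only values <= target qualify
        candidates = [(target - x, i) for i, x in enumerate(data) if x <= target]
    else:
        raise ValueError("direction must be None, 'forward', or 'backward'")
    if not candidates:
        return []  # no value exists on the requested side
    best = min(d for d, _ in candidates)
    return [i for d, i in candidates if d == best]


print(nearest_indices_reference([300, 800, 200], 250))  # [0, 2]: 300 and 200 tie
```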
The idea for this code was inspired by this SO post. Searching the entire array isn't ideal (`np.searchsorted` might be an alternative), but I want to find all occurrences of the user-specified nearest value in the given data array. There are other techniques that achieve the same goal, such as taking the cumulative sum of differences of argsorted values, but I find the code below easier to read and understand, and I expect it to be faster because it performs fewer operations that traverse the entire data array.

That said, I would like to know whether there are faster approaches that produce the same output, since this code will be applied to a dataset of at least ~70,000 data points. I care more about the indices at which the nearest value occurs than about the value itself.
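Here is one way the `searchsorted` alternative might look, assuming many queries are run against the same array. The array is sorted once with a stable `np.argsort`, and each query is then answered with binary searches instead of full passes. `build_sorted_index` and `nearest_indices_sorted` are hypothetical names. The sketch aims to return the same indices as a brute-force scan (all ties, in ascending index order) and raises `ValueError` when no forward or backward match exists.

```python
import numpy as np


def build_sorted_index(data):
    """Sort once (O(n log n)) so each later query costs O(log n + k)."""
    data = np.asarray(data)
    order = np.argsort(data, kind="stable")  # original indices in ascending value order
    return data[order], order


def nearest_indices_sorted(sorted_vals, order, search_value, direction=None):
    """Original-array indices of every occurrence of the nearest value."""
    n = len(sorted_vals)
    lo = int(np.searchsorted(sorted_vals, search_value, side="left"))   # first >= value
    hi = int(np.searchsorted(sorted_vals, search_value, side="right"))  # first > value
    fwd = sorted_vals[lo] if lo < n else None      # smallest value >= search_value
    bwd = sorted_vals[hi - 1] if hi > 0 else None  # largest value <= search_value

    if direction == "forward":
        if fwd is None:
            raise ValueError("no forward nearest match exists")
        targets = [fwd]
    elif direction == "backward":
        if bwd is None:
            raise ValueError("no backward nearest match exists")
        targets = [bwd]
    elif direction is None:
        # a set removes the duplicate when search_value itself is present
        candidates = {v for v in (fwd, bwd) if v is not None}
        best = min(abs(v - search_value) for v in candidates)
        targets = [v for v in candidates if abs(v - search_value) == best]
    else:
        raise ValueError("direction must be None, 'forward', or 'backward'")

    hits = []
    for t in targets:
        # every occurrence of t is one contiguous run in the sorted array
        a = np.searchsorted(sorted_vals, t, side="left")
        b = np.searchsorted(sorted_vals, t, side="right")
        hits.append(order[a:b])
    return np.sort(np.concatenate(hits))
```

The sort costs O(n log n) up front, so this only pays off when the same array is queried repeatedly. For a single query, one linear scan is already O(n).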
```python
import numpy as np
```
Sample Data
```python
sample = np.array([300, 800, 200, 500, 600, 750, 700, 450, 400, 550, 350, 900])
# sample = np.array([300, 800, 200, 500, 600, 750, 700, 450, 400, 550, 350, 900] * 2)
```
Main algorithm
```python
def search_nearest(data, search_value, direction=None):
    """
    This function can find the nearest, forward-nearest, or
    backward-nearest value in data relative to the given search value,
    and returns the indices of all of its occurrences.
    """
    if not isinstance(data, np.ndarray):
        data = np.array(data)
    print("\n>> DATA\n{}\n".format(data))
    print(">> SEARCH VALUE\n{}\n".format(search_value))
    if direction is None:
        # distance in either direction
        delta = np.abs(data - search_value)
        res = np.where(delta == np.min(delta))[0]
    elif direction == 'forward':
        # non-negative deltas correspond to values >= search_value
        delta = data - search_value
        try:
            res = np.where(delta == np.min(delta[delta >= 0]))[0]
        except ValueError:
            # np.min raises ValueError on an empty selection
            raise ValueError("no forward nearest match exists") from None
    elif direction == 'backward':
        # non-negative deltas correspond to values <= search_value
        delta = search_value - data
        try:
            res = np.where(delta == np.min(delta[delta >= 0]))[0]
        except ValueError:
            raise ValueError("no backward nearest match exists") from None
    else:
        raise ValueError("direction must be None, 'forward', or 'backward'")
    print(" .. INDEX OF NEAREST NUMBER\n{}\n".format(res))
    print(" .. NUMBER AT THAT INDEX\n{}\n".format(data[res]))
    print("--------------------")
    return res
```
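If the printing is dropped and the indices are returned directly, the `try`/`except` can be replaced by an explicit emptiness check. This is only a sketch: `nearest_indices_quiet` is a hypothetical name, and returning an empty index array (rather than raising) when no match exists is an assumed design choice.

```python
import numpy as np


def nearest_indices_quiet(data, search_value, direction=None):
    """Indices of every occurrence of the (forward/backward) nearest value.

    Returns an empty integer array instead of raising when no match exists.
    """
    data = np.asarray(data)
    if direction is None:
        delta = np.abs(data - search_value)
        valid = np.ones(data.shape, dtype=bool)  # every element is a candidate
    elif direction == "forward":
        delta = data - search_value
        valid = delta >= 0                       # values >= search_value
    elif direction == "backward":
        delta = search_value - data
        valid = delta >= 0                       # values <= search_value
    else:
        raise ValueError("direction must be None, 'forward', or 'backward'")

    if not valid.any():                          # explicit check replaces try/except
        return np.array([], dtype=np.intp)
    best = delta[valid].min()
    # best >= 0, so delta == best can only match elements in the valid set
    return np.flatnonzero(delta == best)
```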
Call the main function
```python
# crd = None
crd = 'forward'
# crd = 'backward'
for val in (799, 301, 800, 250, 8, 901):
    search_nearest(sample, search_value=val, direction=crd)
```
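When many search values are processed in one direction, as in the loop above, the per-value calls can be batched. The sketch below (`directional_nearest_values` is a hypothetical name) sorts once and runs a single vectorized `np.searchsorted` over all queries. It returns only the nearest values and uses NaN where no forward or backward match exists, which is why the data are converted to float.

```python
import numpy as np


def directional_nearest_values(data, queries, direction="forward"):
    """Forward/backward nearest value for each query; NaN where none exists."""
    sorted_vals = np.sort(np.asarray(data, dtype=float))
    q = np.asarray(queries, dtype=float)
    if direction == "forward":
        pos = np.searchsorted(sorted_vals, q, side="left")       # first value >= query
    elif direction == "backward":
        pos = np.searchsorted(sorted_vals, q, side="right") - 1  # last value <= query
    else:
        raise ValueError("direction must be 'forward' or 'backward'")
    ok = (pos >= 0) & (pos < len(sorted_vals))  # positions that landed inside the array
    out = np.full(q.shape, np.nan)
    out[ok] = sorted_vals[pos[ok]]
    return out


sample = np.array([300, 800, 200, 500, 600, 750, 700, 450, 400, 550, 350, 900])
# forward: 800, 350, 800, 300, 200, then nan for 901
print(directional_nearest_values(sample, (799, 301, 800, 250, 8, 901), "forward"))
```

Because the question is mostly about indices, each finite result `v` would still need a lookup such as `np.flatnonzero(sample == v)`, which is a full pass per query. Keeping the permutation from `np.argsort` and slicing it between the `searchsorted` bounds of `v` avoids that extra pass.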