Skip to content

Commit 0ef5495

Browse files
committed
Use Sentence Transformers to Encode, Query Schedule.org Headings
0 parents  commit 0ef5495

File tree

3 files changed

+104
-0
lines changed

3 files changed

+104
-0
lines changed

‎.gitignore‎

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
.*

‎requirements.txt‎

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
torch==1.6.0
2+
transformers==3.3.1
3+
sentence-transformers==0.3.8
4+
pandas==1.1.2
5+
faiss-cpu==1.6.1
6+
numpy==1.18.5

‎similarity.py‎

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import pandas as pd
2+
import faiss
3+
import numpy as np
4+
5+
from sentence_transformers import SentenceTransformer
6+
7+
import argparse
8+
import os
9+
10+
def create_index(
        model,
        dataset_path,
        index_path,
        column_name,
        recreate):
    """Load the dataset and return a FAISS index over one of its columns.

    The CSV at ``dataset_path`` is read, rows containing missing values are
    dropped, and the text in ``column_name`` is stripped of surrounding
    whitespace. If a file already exists at ``index_path`` and ``recreate``
    is false, the index is read from that file as-is (it is not checked
    against the dataset). Otherwise the column is encoded with ``model``,
    a flat L2 index keyed by the dataset's index labels is built, and the
    index is written to ``index_path``.

    Args:
        model: Encoder exposing ``encode(list_of_str, show_progress_bar=...)``.
        dataset_path: Path of the CSV file to load.
        index_path: Path the FAISS index is read from / written to.
        column_name: Name of the dataset column to clean and index.
        recreate: When true, rebuild the index even if ``index_path`` exists.

    Returns:
        A ``(index, dataset)`` tuple: the FAISS index and the cleaned
        DataFrame.
    """
    # Load and clean the dataset
    frame = pd.read_csv(dataset_path).dropna()
    frame[column_name] = frame[column_name].str.strip()

    # Reuse the stored index unless a rebuild was requested
    index_on_disk = os.path.exists(index_path)
    if index_on_disk and not recreate:
        return faiss.read_index(index_path), frame

    # Encode every document in the column into an embedding vector
    documents = frame[column_name].to_list()
    encoded = model.encode(documents, show_progress_bar=True)
    vectors = np.array(list(encoded)).astype("float32")

    # Flat L2 index wrapped so vectors are addressed by dataset index label
    dimension = vectors.shape[1]
    flat_index = faiss.IndexFlatL2(dimension)
    id_index = faiss.IndexIDMap(flat_index)
    id_index.add_with_ids(vectors, frame.index.values)

    faiss.write_index(id_index, index_path)

    return id_index, frame
40+
41+
42+
def resolve_column(dataset, Id, column):
    """Look up the ``column`` values for each id in the first row of ``Id``.

    Args:
        dataset: DataFrame whose index labels are compared against the ids.
        Id: Sequence of id rows; only ``Id[0]`` is used.
        column: Name of the column to read values from.

    Returns:
        A list with one entry per id in ``Id[0]``; each entry is the list of
        ``column`` values from the rows whose index label equals that id
        (empty when no row matches).
    """
    resolved = []
    for idx in Id[0]:
        # Boolean mask over the rows whose index label equals this id
        row_mask = dataset.index == idx
        resolved.append(list(dataset.loc[row_mask, column]))
    return resolved
44+
45+
46+
def vector_search(query, index, dataset, column_name, num_results=10):
    """Search ``index`` for the entries nearest to ``query``.

    Args:
        query: Query embedding(s). Converted to a 2-D float32 array before
            searching, so a single 1-D vector is accepted as well. Only the
            results for the first query row are returned.
        index: Object exposing ``search(vectors, k=...)`` that returns a
            ``(distances, ids)`` pair, such as the index built by
            ``create_index``.
        dataset: DataFrame whose index labels are the ids stored in ``index``.
        column_name: Name of the dataset column to resolve for each hit.
        num_results: Maximum number of neighbours to request.

    Returns:
        An iterator of ``(distance, id, values)`` tuples in the order
        produced by ``index.search``, where ``values`` is the list of
        ``column_name`` values for that id. Fewer than ``num_results``
        tuples are produced when the index holds fewer vectors.
    """
    query_vector = np.atleast_2d(np.asarray(query, dtype="float32"))
    D, Id = index.search(query_vector, k=num_results)

    # FAISS pads the result with id -1 when it cannot find k neighbours.
    # Those slots are not real hits and resolve to no dataset row, so drop
    # them instead of handing callers an empty match.
    distances = np.asarray(D)[0]
    ids = np.asarray(Id)[0]
    found = ids != -1
    distances, ids = distances[found], ids[found]

    return zip(distances, ids, resolve_column(dataset, [ids], column_name))
51+
52+
if __name__ == '__main__':
    # Command-line entry point: encode the user's positive (and optional
    # negative) terms, search the dataset index, and print the closest entries.
    parser = argparse.ArgumentParser(description="Find most suitable match based on users exclude, include preferences")
    parser.add_argument('positives', type=str, help="Terms to find closest match to")
    parser.add_argument('--negatives', '-n', type=str, help="Terms to find farthest match from")

    parser.add_argument('--recreate', action='store_true', default=False, help="Recreate index at index_path from dataset at dataset path")
    parser.add_argument('--index', type=str, default="./.faiss_index", help="Path to index for storing vector embeddings")
    parser.add_argument('--dataset', type=str, default="./.dataset", help="Path to dataset to generate index from")
    parser.add_argument('--column', type=str, default="DATA", help="Name of dataset column to index")
    parser.add_argument('--num_results', type=int, default=10, help="Number of most suitable matches to show")
    parser.add_argument('--model_name', type=str, default='paraphrase-distilroberta-base-v1', help="Specify name of the SentenceTransformer model to use for encoding")
    args = parser.parse_args()

    # An empty positives string previously fell through both branches and
    # the script exited silently; report it as a usage error instead.
    if not args.positives:
        parser.error("positives must be a non-empty string")

    model = SentenceTransformer(args.model_name)

    # Get index, create it from dataset if doesn't exist
    index, dataset = create_index(model, args.dataset, args.index, args.column, args.recreate)

    # Create vector to represent user's stated preference: the positive
    # embedding, minus the negative embedding when one was given
    preference_vector = np.array(model.encode([args.positives])).astype("float32")
    if args.negatives:
        negatives_vector = np.array(model.encode([args.negatives])).astype("float32")
        preference_vector = preference_vector - negatives_vector

    # Find and display most suitable matches for users preferences in the dataset
    results = vector_search(preference_vector, index, dataset, args.column, args.num_results)

    print("Most Suitable Matches:")
    for similarity, id_, data in results:
        # Skip ids that resolved to no dataset row (e.g. padding returned
        # when the index holds fewer entries than requested)
        if not data:
            continue
        # NOTE: the value printed as "Similarity" is the distance reported
        # by the index search; create_index builds an L2 index.
        print(f"Id: {id_}\nSimilarity: {similarity}\n{args.column}: {data[0]}")

0 commit comments

Comments
 (0)