wbrooks committed
Commit 6b6def4 · 1 Parent(s): 88bbcb9

copied encode function directly into search_embeddings.py

Files changed (1)
  1. src/search_embeddings.py +16 -1
src/search_embeddings.py CHANGED
@@ -2,13 +2,28 @@
  import numpy as np
  import polars as pl

- from src.encode import encode
  from sklearn.metrics.pairwise import cosine_similarity

  from transformers import AutoTokenizer, AutoModel

  import glob

+ import torch
+
+ #
+ def encode(sentences, tokenizer, model, device="mps"):
+     inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt").to(device = device)
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     # outputs.last_hidden_state = [batch, tokens, hidden_dim]
+     # mean pooling
+     embeddings = outputs.last_hidden_state.mean(dim=1)
+
+     return(embeddings)
+
+

  # define the device where torch calculations take place
  my_device = "cpu"
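
For context, a minimal usage sketch of the copied encode() helper (not part of the commit). The checkpoint name below is an assumption for illustration; any AutoModel checkpoint whose outputs expose last_hidden_state works the same way, and the device matches the script's "cpu" setting rather than the helper's "mps" default.

import torch
from transformers import AutoTokenizer, AutoModel

my_device = "cpu"  # same device setting the script defines

# Hypothetical checkpoint for illustration; search_embeddings.py may load a different model.
checkpoint = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint).to(my_device)

sentences = ["first query", "second query"]
embeddings = encode(sentences, tokenizer, model, device=my_device)
print(embeddings.shape)  # torch.Size([2, hidden_dim]): one mean-pooled vector per sentence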