armaniii commited on
Commit
4472504
·
verified ·
1 Parent(s): 2d67e6e

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +28 -0
README.md CHANGED
@@ -11,6 +11,34 @@ pipeline_tag: text-classification
11
 
12
  ## Model Details
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  ### Model Description
15
 
16
  <!-- Provide a longer summary of what this model is. -->
 
11
 
12
  ## Model Details
13
 
14
+ ```python
15
+ from transformers import AutoModelForCausalLM, AutoTokenizer
16
+
17
+ device = "cuda" # the device to load the model onto
18
+
19
+ model = AutoModelForCausalLM.from_pretrained("armaniii/llama-argument-classification")
20
+ tokenizer = AutoTokenizer.from_pretrained("armaniii/lllama-argument-classification")
21
+
22
+ model.to(device)
23
+ model.eval()
24
+
25
+ for batch in tqdm.tqdm(data):
26
+ with torch.no_grad():
27
+ input_text = tokenizer(batch, padding=True, truncation=True,max_length=2048,return_tensors="pt").to(device)
28
+ output = model(**input_text)
29
+ logits = output.logits
30
+ predicted_class = torch.argmax(logits, dim=1)
31
+ # Convert logits to a list of predicted labels
32
+ predictions.extend(predicted_class.cpu().tolist())
33
+
34
+ # Get the ground truth labels
35
+ df["predictions"] = predictions
36
+
37
+ num2label = {
38
+ 0:"NoArgument",
39
+ 1:"Argument"
40
+ }
41
+ ```
42
  ### Model Description
43
 
44
  <!-- Provide a longer summary of what this model is. -->