Update README.md

README.md CHANGED
@@ -91,13 +91,13 @@ from models import CLIPVisionTower
 DEVICE = "cuda:0"
 PROMPT = "This is a dialog with AI assistant.\n"
 
-tokenizer = AutoTokenizer.from_pretrained("AIRI-Institute/OmniFusion", subfolder="OmniMistral-tokenizer", use_fast=False)
-model = AutoModelForCausalLM.from_pretrained("AIRI-Institute/OmniFusion", subfolder="OmniMistral-model", torch_dtype=torch.bfloat16, device_map=DEVICE)
+tokenizer = AutoTokenizer.from_pretrained("AIRI-Institute/OmniFusion", subfolder="OmniMistral-v1_1/tokenizer", use_fast=False)
+model = AutoModelForCausalLM.from_pretrained("AIRI-Institute/OmniFusion", subfolder="OmniMistral-v1_1/tuned-model", torch_dtype=torch.bfloat16, device_map=DEVICE)
 
-hf_hub_download(repo_id="AIRI-Institute/OmniFusion", filename="projection", local_dir='./')
-hf_hub_download(repo_id="AIRI-Institute/OmniFusion", filename="special_embeddings.pt", local_dir='./')
-projection = torch.load("projection", map_location=DEVICE)
-special_embs = torch.load("special_embeddings.pt", map_location=DEVICE)
+hf_hub_download(repo_id="AIRI-Institute/OmniFusion", filename="OmniMistral-v1_1/projection.pt", local_dir='./')
+hf_hub_download(repo_id="AIRI-Institute/OmniFusion", filename="OmniMistral-v1_1/special_embeddings.pt", local_dir='./')
+projection = torch.load("OmniMistral-v1_1/projection.pt", map_location=DEVICE)
+special_embs = torch.load("OmniMistral-v1_1/special_embeddings.pt", map_location=DEVICE)
 
 clip = CLIPVisionTower("openai/clip-vit-large-patch14-336")
 clip.load_model()
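For reference, the block below assembles the "+" side of this hunk into a single loading snippet. Only `from models import CLIPVisionTower` is visible in the hunk header, so the remaining imports (`torch`, `huggingface_hub`, `transformers`) are assumptions inferred from the calls, standing in for whatever the surrounding README already imports.

```python
# Sketch of the post-commit loading code (the "+" side of the hunk above).
# Assumed imports: only "from models import CLIPVisionTower" appears in the diff;
# the rest are inferred from the calls used below.
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM
from models import CLIPVisionTower  # local module from the OmniFusion code base

DEVICE = "cuda:0"
PROMPT = "This is a dialog with AI assistant.\n"

# Tokenizer and tuned LLM now live under the OmniMistral-v1_1/ subfolder of the HF repo.
tokenizer = AutoTokenizer.from_pretrained(
    "AIRI-Institute/OmniFusion",
    subfolder="OmniMistral-v1_1/tokenizer",
    use_fast=False,
)
model = AutoModelForCausalLM.from_pretrained(
    "AIRI-Institute/OmniFusion",
    subfolder="OmniMistral-v1_1/tuned-model",
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)

# The projection and special-embedding checkpoints are downloaded under their new
# "OmniMistral-v1_1/..." filenames, so torch.load must point at the same paths.
hf_hub_download(repo_id="AIRI-Institute/OmniFusion",
                filename="OmniMistral-v1_1/projection.pt", local_dir="./")
hf_hub_download(repo_id="AIRI-Institute/OmniFusion",
                filename="OmniMistral-v1_1/special_embeddings.pt", local_dir="./")
projection = torch.load("OmniMistral-v1_1/projection.pt", map_location=DEVICE)
special_embs = torch.load("OmniMistral-v1_1/special_embeddings.pt", map_location=DEVICE)

# The CLIP vision tower is untouched by this commit.
clip = CLIPVisionTower("openai/clip-vit-large-patch14-336")
clip.load_model()
```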
@@ -149,7 +149,7 @@ def gen_answer(model, tokenizer, clip, projection, query, special_embs, image=No
     return generated_texts
 
 img_url = "https://i.pinimg.com/originals/32/c7/81/32c78115cb47fd4825e6907a83b7afff.jpg"
-question = "
+question = "What is the sky color on this image?"
 img = Image.open(urlopen(img_url))
 
 answer = gen_answer(
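The `answer = gen_answer(` call is cut off in this hunk. The sketch below is a hypothetical completion that continues from the loading snippet above: the argument order follows the parameter names in the `gen_answer` signature shown in the hunk header, while the keyword mapping and the `PIL`/`urllib` imports are assumptions, not part of the commit.

```python
# Hypothetical completion of the truncated gen_answer(...) call, assuming the
# gen_answer helper and the variables from the loading snippet are already defined.
from urllib.request import urlopen
from PIL import Image

img_url = "https://i.pinimg.com/originals/32/c7/81/32c78115cb47fd4825e6907a83b7afff.jpg"
question = "What is the sky color on this image?"
img = Image.open(urlopen(img_url))

# Argument names follow the signature shown in the hunk header:
# gen_answer(model, tokenizer, clip, projection, query, special_embs, image=...)
answer = gen_answer(
    model,
    tokenizer,
    clip,
    projection,
    query=question,
    special_embs=special_embs,
    image=img,
)
print(answer)
```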