"""
CUDA_VISIBLE_DEVICES=0,1
"""
import logging
import torch
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
from PIL import Image
import requests
from io import BytesIO
from transformers import TextStreamer


def load_image(image_file):
    """Load an image from a local path or an http(s) URL and return it as RGB."""
    if image_file.startswith('http://') or image_file.startswith('https://'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


def load_model(hf_model_name='liuhaotian/llava-v1.6-34b'):
    """Load a pretrained LLaVA model (4-bit by default) and pick the matching conversation template."""
    disable_torch_init()
    # model_path = 'liuhaotian/llava-v1.5-7b'
    # model_path = 'liuhaotian/llava-v1.5-13b'
    model_path = hf_model_name
    model_name = get_model_name_from_path(model_path)  # e.g. 'llava-v1.6-34b'
    model_base = None
    load_8bit = False
    load_4bit = True
    device = 'cuda'
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        model_path, model_base, model_name, load_8bit, load_4bit, device=device)
    if 'llama-2' in model_name.lower():
        conv_mode = "llava_llama_2"
    elif "v1" in model_name.lower():
        conv_mode = "llava_v1"
    elif "mpt" in model_name.lower():
        conv_mode = "mpt"
    else:
        conv_mode = "llava_v0"
    return tokenizer, model, image_processor, conv_mode


def build_prompt(image_message=None, system_message=None):
    # https://docs.google.com/document/d/1CflrE1mNU-rz_j7H2Au580JA9JYGKckATMNS5uQrG2w/edit#heading=h.hrdqg3a8zs4
    original_system_message = ("A chat between a curious human and an artificial intelligence assistant. "
                               "The assistant gives helpful, detailed, and polite answers to the human's questions.")
    sys_message = system_message if system_message is not None else original_system_message
    if image_message is not None:
        return f'{sys_message} USER: {image_message} ASSISTANT:'
    else:
        return (f'{sys_message} USER: <image> Describe the image in detail. '
                f'What are the primary objects in this image? '
                f'Does this image have an identifiable landmark or tag? ASSISTANT:')


def generate(image_file: str,
             user_message: str,
             system_message: str,
             tokenizer, model,
             image_processor, conv_mode, temperature: float = 0., max_new_tokens: int = 512):
    """Run one round of LLaVA inference on a single image.

    @Args:
        image_file: local path or URL, e.g. "https://llava-vl.github.io/static/images/view.jpg"
        user_message: question or instruction about the image
        system_message: optional system prompt; None falls back to the default in build_prompt()
    """
    conv = conv_templates[conv_mode].copy()
    # if "mpt" in model_name.lower():
    #     roles = ('user', 'assistant')
    # else:
    #     roles = conv.roles
    roles = conv.roles
    # logging.info(f'Roles: {roles}')  # ('USER', 'ASSISTANT')
    image = load_image(image_file)
    image_tensor = process_images([image], image_processor, model.config)
    if type(image_tensor) is list:
        image_tensor = [image.to(model.device, dtype=torch.float16) for image in image_tensor]
    else:
        image_tensor = image_tensor.to(model.device, dtype=torch.float16)
    if image is not None:
        # first message: prepend the image token(s) to the user question
        if model.config.mm_use_im_start_end:
            image_message = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + user_message
        else:
            image_message = DEFAULT_IMAGE_TOKEN + '\n' + user_message
        conv.append_message(conv.roles[0], image_message)  # USER: <image> {question}
        image = None
    else:
        # later messages carry no image token
        image_message = user_message
        conv.append_message(conv.roles[0], user_message)
    # message = "<image> prompt"
    conv.append_message(conv.roles[1], None)  # ASSISTANT:
    # prompt = conv.get_prompt()
    # The resulting prompt has the form:
    #   A chat between a curious human and an artificial intelligence assistant.
    #   The assistant gives helpful, detailed, and polite answers to the human's questions.
    #   USER: <image> {question} ASSISTANT:
    prompt = build_prompt(image_message, system_message)
    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True if temperature > 0 else False,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            streamer=streamer,
            use_cache=True,
            stopping_criteria=[stopping_criteria])
    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
    # str.rstrip('</s>') would strip any trailing '<', '/', 's', '>' characters,
    # so remove the end-of-sequence marker as a whole suffix instead.
    if outputs.endswith('</s>'):
        outputs = outputs[:-len('</s>')]
    resp = outputs.strip()
    logging.info(f'Q: {prompt}\nA: {resp}')
    return resp
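

# A minimal usage sketch (not part of the original helpers): wire load_model() and
# generate() together for one image. The image URL is the LLaVA demo image already
# referenced in generate()'s docstring; the question text and the temperature/token
# settings below are illustrative assumptions and can be replaced.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    tokenizer, model, image_processor, conv_mode = load_model('liuhaotian/llava-v1.6-34b')
    answer = generate(
        image_file="https://llava-vl.github.io/static/images/view.jpg",
        user_message="Describe the image in detail.",
        system_message=None,
        tokenizer=tokenizer,
        model=model,
        image_processor=image_processor,
        conv_mode=conv_mode,
        temperature=0.,
        max_new_tokens=512)
    print(answer)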