DaddyDaniel committed
Commit 8557bbe · 1 Parent(s): 979c542

Fix inference and Dockerfile


- Fix model generation
- Add proper output rendering
- Add download buttons

Files changed (9)
  1. .dockerignore +1 -0
  2. Dockerfile +34 -0
  3. NLP_Group_logo.png +0 -0
  4. main.py +6 -1
  5. main_page.py +6 -0
  6. qwen2_inference.py +62 -12
  7. requirements.txt +11 -2
  8. sketch2diagram.py +39 -11
  9. util.py +26 -0
.dockerignore ADDED
@@ -0,0 +1 @@
+.venv
Dockerfile ADDED
@@ -0,0 +1,34 @@
+FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
+
+# Set environment variables to reduce interactive prompts
+ENV DEBIAN_FRONTEND=noninteractive
+
+# Install dependencies
+RUN apt-get update && apt-get install -y \
+    python3.10 \
+    python3-pip \
+    git \
+    texlive-latex-base \
+    texlive-latex-extra \
+    texlive-fonts-recommended \
+    texlive-latex-recommended \
+    latexmk \
+    poppler-utils \
+    && rm -rf /var/lib/apt/lists/*
+
+# Copy the files
+WORKDIR /app
+COPY requirements.txt .
+
+RUN pip install --upgrade pip \
+    && pip install --no-cache-dir -r requirements.txt
+
+ENV PATH="/root/.local/bin:$PATH"
+ENV STREAMLIT_WATCHER_TYPE none
+
+RUN pip install --no-cache-dir https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.6/flash_attn-2.6.3+cu124torch2.6-cp310-cp310-linux_x86_64.whl
+
+COPY . .
+
+# Default command
+ENTRYPOINT ["streamlit", "run", "main.py"]
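
Note that the prebuilt flash-attn wheel above is pinned to CUDA 12.4, torch 2.6, and CPython 3.10, so it only imports cleanly when the runtime matches all three. A minimal sanity check (a sketch, not part of the commit) that can be run inside the container:

# Sketch: verify the runtime matches the pinned flash-attn wheel (cu124, torch 2.6, cp310).
import sys
import torch

assert sys.version_info[:2] == (3, 10), sys.version
assert torch.__version__.startswith("2.6"), torch.__version__
assert (torch.version.cuda or "").startswith("12.4"), torch.version.cuda

import flash_attn  # raises ImportError if the wheel's ABI does not match the runtime
print("flash-attn", flash_attn.__version__, "loaded OK")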
NLP_Group_logo.png ADDED
main.py CHANGED
@@ -1,6 +1,11 @@
+import os
+
 import streamlit as st
+from PIL import Image
 
-st.logo("NLP_Group_logo.svg", size="large")
+logo_path = os.path.join(os.path.dirname(__file__), "NLP_Group_logo.png")
+logo = Image.open(logo_path)
+st.logo(logo, size="large")
 main_page = st.Page("main_page.py", title="Main Page", icon="🏠")
 sketch2diagram_page = st.Page("sketch2diagram.py", title="Sketch2Diagram", icon="🖼️")
 # Add pages to the main page
main_page.py CHANGED
@@ -3,3 +3,9 @@ import streamlit as st
 st.title("Tohoku NLP Group - Language and Information Science Laboratory ")
 st.write("Welcome to the Language and Information Science Laboratory!")
 st.write("We are working on various projects and research focused on Visual Language Models.")
+
+
+# Link to sketch2diagram page
+st.subheader("You can check out our models and demos here:")
+
+st.write("[Sketch2Diagram](sketch2diagram) - A model that generates TikZ code from sketches.")
qwen2_inference.py CHANGED
@@ -1,21 +1,47 @@
+import os
+
 import streamlit as st
 import torch
 from PIL import Image
+from dotenv import load_dotenv
+from qwen_vl_utils import process_vision_info
 from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
 
+load_dotenv()
+HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
+
+
+def print_gpu_memory(label, memory_allocated, memory_reserved):
+    if torch.cuda.is_available():
+        print("-----------------------------------")
+        print(f"{label} GPU Memory Usage:")
+        print(f"Allocated: {memory_allocated / 1024 ** 2:.2f} MB")
+        print(f"Cached: {memory_reserved / 1024 ** 2:.2f} MB")
+
 
 # Inference steps taken from https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct
 
-@st.cache_resource
+# @st.cache_resource
 def get_model(model_path):
     try:
         with st.spinner(f"Loading model {model_path}"):
-            device = "cuda" if torch.cuda.is_available() else "cpu"
             # Load the model here
             model_import = Qwen2VLForConditionalGeneration.from_pretrained(
-                model_path, torch_dtype="auto", device_map=device
+                model_path, torch_dtype="auto", device_map="auto",
+                attn_implementation="flash_attention_2",
+                token=HUGGINGFACE_TOKEN,
             )
-            processor_import = AutoProcessor.from_pretrained(model_path)
+            model_import = model_import.to("cuda")
+            size = {
+                "shortest_edge": 224,
+                "longest_edge": 1024,
+            }
+            processor_import = AutoProcessor.from_pretrained("itsumi-st/imgtikz_qwen2vl",
+                                                             size=size,
+                                                             min_pixels=256 * 256,
+                                                             max_pixels=1024 * 1024,
+                                                             token=HUGGINGFACE_TOKEN)
+            processor_import.tokenizer.padding_side = 'left'
 
             return model_import, processor_import
     except Exception as e:
@@ -27,27 +53,43 @@ def run_inference(input_file, model_path, args):
     model, processor = get_model(model_path)
     if model is None or processor is None:
         return "Error loading model."
+
+    # GPU Memory after model loading:
+    after_model_dump = (torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
+
     image = Image.open(input_file)
     conversation = [
         {
             "role": "user",
             "content": [
-                {"type": "image"},
+                {"type": "image", "image": image},
                 {"type": "text", "text": "Please generate TikZ code to draw the diagram of the given image."}
             ],
         }
     ]
-    text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
-    inputs = processor(image, text_prompt, return_tensors="pt").to("cuda")
+    text_prompt = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
+    image_input, video_inputs = process_vision_info(conversation)
+    inputs = processor(
+        text=[text_prompt],
+        images=image_input,
+        videos=video_inputs,
+        padding=True,
+        return_tensors="pt",
+    ).to("cuda")
+
+    # GPU Memory after input processing
+    after_input_dump = (torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
 
     output_ids = model.generate(**inputs,
-                                max_new_tokens=args.max_length,
+                                max_new_tokens=args['max_length'],
                                 do_sample=True,
-                                top_p=args.top_p,
-                                top_k=args.top_k,
+                                top_p=args['top_p'],
+                                top_k=args['top_k'],
+                                use_cache=True,
                                 num_return_sequences=1,
-                                temperature=args.temperature
-                                )
+                                pad_token_id=processor.tokenizer.pad_token_id,
+                                temperature=args['temperature']
+                                )
     generated_ids = [
         output_ids[len(input_ids):]
         for input_ids, output_ids in zip(inputs.input_ids, output_ids)
@@ -55,4 +97,12 @@ def run_inference(input_file, model_path, args):
     output_text = processor.batch_decode(
         generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
+
+    # GPU Memory after generation
+    after_gen_dump = (torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
+
+    print_gpu_memory("Before Model", after_model_dump[0], after_model_dump[1])
+    print_gpu_memory("After Input", after_input_dump[0], after_input_dump[1])
+    print_gpu_memory("After Generation", after_gen_dump[0], after_gen_dump[1])
+
     return output_text
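
For reference, a minimal driver for the updated run_inference (a sketch assuming a CUDA GPU is present; the args keys mirror the sidebar sliders in sketch2diagram.py, and "sketch.png" is a placeholder input path):

# Hypothetical standalone call to run_inference; not part of the commit.
from qwen2_inference import run_inference

args = {
    "max_length": 2048,   # forwarded to model.generate as max_new_tokens
    "temperature": 0.6,
    "top_p": 1.0,
    "top_k": 50,
}
# batch_decode returns a list of strings; with num_return_sequences=1 there is one candidate.
tikz_code = run_inference("sketch.png", "itsumi-st/imgtikz_qwen2vl", args)[0]
print(tikz_code)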
requirements.txt CHANGED
@@ -1,3 +1,12 @@
 streamlit~=1.43.2
-transformers~=4.50.0
-pillow~=11.1.0
+torch==2.6.0
+torchvision==0.21.0
+torchaudio
+transformers==4.48.2
+qwen-vl-utils==0.0.10
+packaging
+accelerate==1.0.1
+requests
+pillow
+python-dotenv
+pdf2image
sketch2diagram.py CHANGED
@@ -1,6 +1,8 @@
 import streamlit as st
+from pdf2image import convert_from_path
 
 from qwen2_inference import run_inference
+from util import compile_tikz_to_pdf
 
 args = {}
 
@@ -8,12 +10,16 @@ args = {}
 st.sidebar.title("Model Configuration")
 model_name = st.sidebar.selectbox("Model Name", ['Itsumi-st/Imgtikz_Qwen2vl', 'Qwen/Qwen2-VL-7B-Instruct'])
 args['inference_strat'] = st.sidebar.selectbox("Inference Strategy", ["Iterative", "Multi-candidate"],
-                                               help="Choose the inference strategy for the model. Iterative generates one candidate at a time until an output compiles, while Multi-candidate generates multiple candidates in parallel.")
-args['max_length'] = st.sidebar.slider("Max Length", 1, 5096, 2048, help="Maximum length of the generated output. The model will generate text up to this length.")
+                                               help="Choose the inference strategy for the model. Iterative generates one candidate at a time until an output compiles, while Multi-candidate generates multiple candidates in parallel.")
+args['max_length'] = st.sidebar.slider("Max Length", 1, 5096, 2048,
+                                       help="Maximum length of the generated output. The model will generate text up to this length.")
 args['seed'] = st.sidebar.number_input("Seed", min_value=0, value=42, step=1)
-args['top_p'] = st.sidebar.slider("Top P", 0.0, 1.0, 1.0, step=0.01, help="Top P sampling parameter. The model will sample from the top P percentage of the probability distribution.")
-args['temperature'] = st.sidebar.slider("Top P", 0.0, 1.0, 0.6, step=0.01, help="Temperature parameter for sampling. Higher values result in more random outputs.")
-args['top_k'] = st.sidebar.slider("Top K", 0, 100, 50, step=1, help="Top K sampling parameter. The model will sample from the top K tokens with the highest probabilities.")
+args['temperature'] = st.sidebar.slider("Temperature", 0.0, 1.0, 0.6, step=0.01,
+                                        help="Temperature parameter for sampling. Higher values result in more random outputs.")
+args['top_p'] = st.sidebar.slider("Top P", 0.0, 1.0, 1.0, step=0.01,
+                                  help="Top P sampling parameter. The model will sample from the top P percentage of the probability distribution.")
+args['top_k'] = st.sidebar.slider("Top K", 0, 100, 50, step=1,
+                                  help="Top K sampling parameter. The model will sample from the top K tokens with the highest probabilities.")
 
 # Introduction Section
 st.title("Sketch2Diagram")
@@ -22,7 +28,6 @@ st.write("This is a runnable demo of ImgTikZ model introduced in the Sketch2Diagram paper.")
 st.write("Please refer to the [original paper](https://openreview.net/pdf?id=KvaDHPhhir) for more details.")
 st.write("The model is trained to convert sketches into TikZ code, which can be used to generate vectorized diagrams.")
 
-
 # User Input Section
 st.subheader("Upload your sketch")
 
@@ -32,10 +37,10 @@ input_method = st.selectbox("Input Method", ["Upload", "Camera"],
 input_file = None
 if input_method == "Camera":
     input_file = st.camera_input("Take a picture of your sketch")
-    # Implement camera input functionality here
+    # todo: Implement camera input functionality here
 else:
     input_file = st.file_uploader("Upload an image of your sketch", type=["png", "jpg", "jpeg"])
-
+st.write(args)
 generate_command = None
 # Display the uploaded image
 if input_file is not None:
@@ -45,6 +50,29 @@ if input_file is not None:
 # Run model inference
 if generate_command:
     with st.spinner("Generating TikZ code..."):
-        output = run_inference(input_file, model_name, args)
-        st.success("TikZ code generated successfully!")
-        st.code(output, language='latex')
+        output = run_inference(input_file, model_name, args)[0]
+        pdf_file_path = compile_tikz_to_pdf(output)
+        if output and pdf_file_path:
+            st.success("TikZ code generated successfully!")
+            st.code(output, language='latex')
+
+            st.download_button(
+                label="Download LaTeX Code",
+                data=output,
+                file_name="output.tex",
+                mime="text/plain"
+            )
+
+            # st.image(pdf_file_path, caption="Generated Diagram", use_column_width=True)
+            with open(pdf_file_path, "rb") as f:
+                st.download_button(
+                    label="Download PDF",
+                    data=f.read(),  # binary content of the PDF
+                    file_name="output.pdf",
+                    mime="application/pdf"
+                )
+
+            images = convert_from_path(pdf_file_path)
+            st.image(images[0], caption="Generated Diagram", use_column_width=True)
+        else:
+            st.error("Failed to generate TikZ code.")
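
The convert_from_path call shells out to Poppler (the poppler-utils package the Dockerfile installs), so the preview only works where pdftoppm is on PATH. A minimal standalone check of that rendering step (a sketch; "output.pdf" is a placeholder path):

# Sketch: render the first PDF page to a PIL image, as the app does for its preview.
from pdf2image import convert_from_path

pages = convert_from_path("output.pdf", dpi=200)  # returns one PIL.Image per page
pages[0].save("diagram_preview.png")
print(f"Rendered {len(pages)} page(s); first page size: {pages[0].size}")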
util.py ADDED
@@ -0,0 +1,26 @@
+import os
+import subprocess
+import tempfile
+
+
+def compile_tikz_to_pdf(tikz_code):
+    temp_dir = tempfile.mkdtemp()
+
+    tex_path = os.path.join(temp_dir, "output.tex")
+    pdf_path = os.path.join(temp_dir, "output.pdf")
+
+    with open(tex_path, "w") as f:
+        f.write(tikz_code)
+
+    try:
+        subprocess.run(
+            ["pdflatex", "-interaction=nonstopmode", tex_path],
+            cwd=temp_dir,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.PIPE,
+            check=True,
+        )
+        return pdf_path
+    except subprocess.CalledProcessError as e:
+        print("PDF compilation failed:", e)
+        return None
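
A quick usage sketch for compile_tikz_to_pdf: the function writes the string verbatim to output.tex, so the model output must already be a complete standalone LaTeX document for pdflatex to succeed (the standalone class below comes from texlive-latex-extra, which the Dockerfile installs):

# Hypothetical usage of util.compile_tikz_to_pdf with a hand-written TikZ document.
from util import compile_tikz_to_pdf

tikz = r"""\documentclass[tikz]{standalone}
\begin{document}
\begin{tikzpicture}
  \draw (0,0) -- (1,1) -- (2,0) -- cycle;  % a small triangle
\end{tikzpicture}
\end{document}
"""

pdf_path = compile_tikz_to_pdf(tikz)
if pdf_path:
    print("Compiled to", pdf_path)
else:
    print("pdflatex failed; see the log in the temp directory")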