Spaces:
Sleeping
Sleeping
File size: 12,301 Bytes
ed9ec98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 |
import os
import gradio as gr
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.utils import embedding_functions
import ollama # Ollamaライブラリをインポート
from langchain_text_splitters import RecursiveCharacterTextSplitter
import uuid # For generating unique IDs for chunks
# --- Ollamaクライアントの初期化 ---
client_ollama = ollama.Client()
OLLAMA_MODEL_NAME = "llama3.2"
# OLLAMA_MODEL_NAME = "llama3:8b-instruct-q4_0"
# --- 埋め込みモデルの初期化 ---
# 重複定義を削除し、1回のみ初期化
embedding_model = SentenceTransformer('pkshatech/GLuCoSE-base-ja') # 日本語対応の埋め込みモデル
# --- ChromaDBのカスタム埋め込み関数 ---
# 重複定義を削除し、1回のみ定義
class SBERTEmbeddingFunction(embedding_functions.EmbeddingFunction):
def __init__(self, model):
self.model = model
def __call__(self, texts):
# sentence-transformersモデルはnumpy配列を返すため、tolist()でPythonリストに変換
return self.model.encode(texts).tolist()
sbert_ef = SBERTEmbeddingFunction(embedding_model)
# --- ChromaDBクライアントとコレクションの初期化 ---
# インメモリモードで動作させ、アプリケーション起動時にコレクションをリセットします。
# グローバル変数としてクライアントを保持
client = chromadb.Client()
collection_name = "pdf_documents_collection"
# アプリケーション起動時にコレクションが存在すれば削除し、新しく作成する
# (インメモリDBはセッションごとにリセットされるため、これは初回起動時のみ意味を持つ)
try:
client.delete_collection(name=collection_name)
print(f"既存のChromaDBコレクション '{collection_name}' を削除しました。")
except Exception as e:
# コレクションが存在しない場合はエラーになるので無視。デバッグ用にメッセージは出力。
print(f"ChromaDBコレクション '{collection_name}' の削除に失敗しました (存在しないか、その他のエラー): {e}")
pass
collection = client.get_or_create_collection(name=collection_name, embedding_function=sbert_ef)
print(f"ChromaDBコレクション '{collection_name}' を初期化しました。")
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1000, # チャンクの最大文字数
chunk_overlap=150, # チャンク間のオーバーラップ文字数
length_function=len, # 文字数で長さを計算
separators=["\n\n", "\n", " ", ""] # 分割の優先順位
)
# --- ヘルパー関数 ---
def extract_text_from_pdf(pdf_file_path):
"""PDFファイルからテキストを抽出する"""
try:
reader = PdfReader(pdf_file_path)
text = ""
for page in reader.pages:
page_text = page.extract_text()
if page_text: # ページが空でなければ追加
text += page_text + "\n"
if not text.strip(): # 全ページからテキストが抽出できなかった場合
return "PDFからテキストを抽出できませんでした。画像ベースのPDFかもしれません。"
return text
except Exception as e:
return f"PDFの読み込み中にエラーが発生しました: {e}"
def get_ollama_response(query, context, source_code_to_check):
"""Ollamaモデルを使用して質問に回答する"""
messages = [
{"role": "system", "content": "あなたは提供されたコンテキスト(ソースコードチェックリスト)とレビュー対象のソースコードに基づいて、ソースコードをチェックし、その結果を返す有益なアシスタントです。チェックリストの項目ごとにソースコードを評価し、具体的な指摘と改善案を提示してください。コンテキストに情報がない場合は、「提供された情報からは回答できません。」と答えてください。"},
{"role": "user", "content": f"ソースコードチェックリスト:\n{context}\n\nレビュー対象のソースコード:\n```\n{source_code_to_check}\n```\n\n質問: {query}\n\nチェック結果:"}
]
try:
print(f"Ollamaモデル '{OLLAMA_MODEL_NAME}' にリクエストを送信中...")
response = client_ollama.chat(
model=OLLAMA_MODEL_NAME,
messages=messages,
options={
"temperature": 0.5,
"num_predict": 2000 # 回答の最大トークン数
}
)
print(f"Ollamaからの生応答: {response}") # デバッグ用にOllamaの生レスポンスを出力
if 'message' in response and 'content' in response['message']:
return response['message']['content'].strip()
else:
print(f"Ollamaからの応答形式が不正です: {response}")
return "Ollamaモデルからの応答形式が不正です。詳細をコンソールログで確認してください。"
except ollama.ResponseError as e: # Ollama固有のエラーを捕捉
print(f"Ollama APIエラーが発生しました: {e}")
return f"Ollamaモデルの呼び出し中にAPIエラーが発生しました: {e}\nOllamaサーバーが起動しているか、モデル '{OLLAMA_MODEL_NAME}' がインストールされているか確認してください。"
except Exception as e: # その他の予期せぬエラー
print(f"Ollamaモデルの呼び出し中に予期せぬエラーが発生しました: {e}")
return f"Ollamaモデルの呼び出し中に予期せぬエラーが発生しました: {e}"
def upload_pdf_and_process(pdf_files):
"""複数のPDFファイルをアップロードし、テキストを抽出し、ChromaDBに登録する"""
if not pdf_files:
return "PDFファイルがアップロードされていません。", gr.update(interactive=False), gr.update(interactive=False)
processed_files_count = 0
total_chunks_added = 0
all_status_messages = []
for pdf_file in pdf_files:
try:
pdf_path = pdf_file.name
file_name = os.path.basename(pdf_path)
all_status_messages.append(f"PDFファイル '{file_name}' を処理中...")
print(f"Processing PDF: {file_name}")
# 1. PDFからテキストを抽出
raw_text = extract_text_from_pdf(pdf_path)
if "エラー" in raw_text or "抽出できませんでした" in raw_text:
all_status_messages.append(raw_text)
print(f"Error extracting text from {file_name}: {raw_text}")
continue # 次のファイルへ
# 2. テキストをチャンクに分割
chunks = text_splitter.split_text(raw_text)
if not chunks:
all_status_messages.append(f"'{file_name}' から有効なテキストチャンクを抽出できませんでした。")
print(f"No valid chunks extracted from {file_name}.")
continue # 次のファイルへ
# 3. チャンクをChromaDBに登録
documents = chunks
metadatas = [{"source": file_name, "chunk_index": i} for i in range(len(chunks))]
ids = [str(uuid.uuid4()) for _ in range(len(chunks))]
collection.add(
documents=documents,
metadatas=metadatas,
ids=ids
)
processed_files_count += 1
total_chunks_added += len(chunks)
all_status_messages.append(f"PDFファイル '{file_name}' の処理が完了しました。{len(chunks)}個のチャンクがデータベースに登録されました。")
print(f"Finished processing {file_name}. Added {len(chunks)} chunks.")
except Exception as e:
all_status_messages.append(f"PDFファイル '{os.path.basename(pdf_file.name)}' 処理中に予期せぬエラーが発生しました: {e}")
print(f"Unexpected error during processing {os.path.basename(pdf_file.name)}: {e}")
continue # 次のファイルへ
final_status_message = f"{processed_files_count}個のPDFファイルの処理が完了しました。合計{total_chunks_added}個のチャンクがデータベースに登録されました。質問とソースコードを入力してください。\n\n" + "\n".join(all_status_messages)
return final_status_message, gr.update(interactive=True), gr.update(interactive=True)
def answer_question(question, source_code):
"""ChromaDBから関連情報を取得し、Ollamaモデルで質問に回答する"""
if not question and not source_code: # 質問またはソースコードのいずれかがあればOK
return "質問またはレビュー対象のソースコードを入力してください。", ""
if collection.count() == 0:
return "PDFがまだアップロードされていないか、処理されていません。まずPDFをアップロードしてください。", ""
try:
print(f"Searching ChromaDB for question: {question}") # デバッグ出力
results = collection.query(
query_texts=[question],
n_results=8 # 上位8つの関連チャンクを取得
)
context_chunks = results['documents'][0] if results['documents'] else []
if not context_chunks:
print("No relevant context chunks found in ChromaDB.")
return "関連する情報が見つかりませんでした。質問を明確にするか、別のPDFを試してください。", ""
context = "\n\n".join(context_chunks)
print(f"Retrieved context (first 500 chars):\n{context[:500]}...") # デバッグ用にコンテキストの一部を出力
answer = get_ollama_response(question, context, source_code) # 関数名を変更
return answer, context
except Exception as e:
print(f"質問応答中に予期せぬエラーが発生しました: {e}")
return f"質問応答中に予期せぬエラーが発生しました: {e}", ""
# --- Gradio UIの構築 ---
with gr.Blocks() as gradioUI:
gr.Markdown(
f"""
# PDF Q&A with Local LLM (Ollama: {OLLAMA_MODEL_NAME}) and Vector Database
PDFファイルとしてソースコードチェックリストをアップロードし、レビューしたいソースコードを入力してください。
**複数のPDFファイルを同時にアップロードできます。**
ローカルのOllama ({OLLAMA_MODEL_NAME}) を使用しています。
"""
)
with gr.Row():
with gr.Column():
pdf_input = gr.File(label="PDFドキュメントをアップロード", file_types=[".pdf"], file_count="multiple")
upload_status = gr.Textbox(label="ステータス", interactive=False, value="PDFをアップロードしてください。", lines=5)
with gr.Column():
source_code_input = gr.Code(
label="レビュー対象のソースコード (ここにソースコードを貼り付けてください)",
value="",
language="python",
interactive=True,
lines=15
)
question_input = gr.Textbox(label="レビュー指示(例: セキュリティの観点からレビュー)", placeholder="特定の観点からのレビュー指示を入力してください(任意)。", interactive=False)
review_button = gr.Button("レビュー開始")
answer_output = gr.Markdown(label="レビュー結果")
retrieved_context_output = gr.Textbox(label="取得されたチェックリスト項目", interactive=False, lines=10)
pdf_input.upload(
upload_pdf_and_process,
inputs=[pdf_input],
outputs=[upload_status, question_input, source_code_input]
)
review_button.click(
answer_question,
inputs=[question_input, source_code_input],
outputs=[answer_output, retrieved_context_output]
)
gradioUI.launch(server_name="0.0.0.0", server_port=7860)
|