sirochild commited on
Commit
733f37f
·
verified ·
1 Parent(s): c51798e

Upload 2 files

Browse files
Files changed (2) hide show
  1. main_app.py +27 -25
  2. session_api_server.py +11 -64
main_app.py CHANGED
@@ -140,35 +140,37 @@ if 'server_started' not in st.session_state:
140
  else:
141
  logger.warning("⚠️ セッション管理サーバー起動失敗 - フォールバックモードで動作")
142
 
143
- query_params = st.experimental_get_query_params()
144
- temp_token = query_params.get("temp_token", [None])[0]
145
 
 
 
 
146
  def get_current_user_id():
147
- if temp_token:
148
- try:
149
- user_info = requests.get(f"http://localhost:8000/user?temp_token={temp_token}").json()
150
- user_id = user_info.get("id")
151
- is_hf_user = True
152
- except Exception as e:
153
- st.warning(f"HF API エラー: {e}. 匿名モードを使用します。")
154
- user_id = str(uuid.uuid4())
155
- is_hf_user = False
156
- else:
157
- user_id = str(uuid.uuid4())
158
- is_hf_user = False
159
- return user_id, is_hf_user
160
-
161
- user_id, is_hf_user = get_current_user_id()
162
- st.write(f"ユーザー ID: {user_id} (HF ログイン: {is_hf_user})")
163
 
164
- if not is_hf_user:
165
- if st.button("HuggingFaceでログイン"):
166
- # FastAPI /login から auth_url を取得
167
- res = requests.get("http://localhost:8000/login").json()
168
- auth_url = res.get("auth_url")
169
- # ユーザーを OAuth ページにリダイレクト
170
- st.markdown(f'<meta http-equiv="refresh" content="0; url={auth_url}">', unsafe_allow_html=True)
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
 
173
  # --- 必要なモジュールのインポート ---
174
 
 
140
  else:
141
  logger.warning("⚠️ セッション管理サーバー起動失敗 - フォールバックモードで動作")
142
 
 
 
143
 
144
+ SPACE_URL = "https://huggingface.co/spaces/sirochild/mari-chat-3"
145
+ LOGIN_URL = f"{SPACE_URL}/oauth/huggingface/login"
146
+ ME_URL = f"{SPACE_URL}/me"
147
  def get_current_user_id():
148
+ # --- セッションステートにユーザー情報を保存 ---
149
+ if "user" not in st.session_state:
150
+ st.session_state.user = None
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
+ # --- ログインしてなければボタンを表示 ---
153
+ if st.session_state.user is None:
154
+ st.markdown(f"[👉 Hugging Faceでログイン]({LOGIN_URL})")
 
 
 
 
155
 
156
+ else:
157
+ st.success(f"ログイン中: {st.session_state.user['preferred_username']}")
158
+ st.json(st.session_state.user)
159
+
160
+ # --- クエリパラメータに code があれば認証処理 ---
161
+ query_params = st.query_params
162
+ if "code" in query_params:
163
+ code = query_params["code"]
164
+ # /me にアクセスしてユーザー情報を取得
165
+ try:
166
+ resp = requests.get(ME_URL, cookies={"hf_oauth_code": code})
167
+ if resp.status_code == 200:
168
+ st.session_state.user = resp.json()
169
+ st.experimental_rerun()
170
+ else:
171
+ st.error("認証失敗しました")
172
+ except Exception as e:
173
+ st.error(f"エラー: {e}")
174
 
175
  # --- 必要なモジュールのインポート ---
176
 
session_api_server.py CHANGED
@@ -12,6 +12,7 @@ import os
12
  import time
13
  import secrets
14
  import base64
 
15
  from datetime import datetime, timedelta
16
  from typing import Optional, Dict, Any
17
  import logging
@@ -36,72 +37,18 @@ app.add_middleware(
36
  allow_headers=["*"],
37
  )
38
 
39
- HF_CLIENT_ID = os.getenv("OAUTH_CLIENT_ID")
40
- HF_CLIENT_SECRET = os.getenv("OAUTH_CLIENT_SECRET")
41
- HF_REDIRECT_URI = os.getenv("OAUTH_REDIRECT_URI", "https://sirochild-mari-chat-3.hf.space/callback")
42
- HF_SCOPES = os.getenv("HF_OAUTH_SCOPES", "openid profile email")
43
 
44
- # 一時トークンとアクセストークンのマッピング
45
- TEMP_STORE = {} # temp_token -> hf_access_token
46
- STATE_STORE = {} # state の管理
47
-
48
- @app.get("/login")
49
- def login():
50
- state = secrets.token_urlsafe(16)
51
- STATE_STORE[state] = True
52
- auth_url = (
53
- f"https://huggingface.co/oauth/authorize"
54
- f"?client_id={HF_CLIENT_ID}"
55
- f"&redirect_uri={HF_REDIRECT_URI}"
56
- f"&response_type=code"
57
- f"&scope={HF_SCOPES}"
58
- f"&state={state}"
59
- )
60
- return RedirectResponse(auth_url)
61
-
62
- @app.get("/callback")
63
- def callback(code: str, state: str):
64
- if state not in STATE_STORE:
65
- return JSONResponse({"error": "Invalid state"}, status_code=400)
66
- del STATE_STORE[state]
67
-
68
- # Basic 認証ヘッダー作成
69
- auth_header = base64.b64encode(f"{HF_CLIENT_ID}:{HF_CLIENT_SECRET}".encode()).decode()
70
-
71
- token_res = requests.post(
72
- "https://huggingface.co/oauth/token",
73
- data={
74
- "grant_type": "authorization_code",
75
- "code": code,
76
- "redirect_uri": HF_REDIRECT_URI,
77
- "client_id": HF_CLIENT_ID,
78
- "client_secret": HF_CLIENT_SECRET,
79
  }
80
- ).json()
81
-
82
- access_token = token_res.get("access_token")
83
- if not access_token:
84
- return JSONResponse({"error": "Failed to get access token"}, status_code=400)
85
-
86
- # 一時トークン発行
87
- temp_token = secrets.token_urlsafe(16)
88
- TEMP_STORE[temp_token] = access_token
89
-
90
- # Streamlit に temp_token を渡す
91
- return RedirectResponse(f"https://sirochild-mari-chat-3.hf.space/?temp_token={temp_token}")
92
-
93
- @app.get("/user")
94
- def get_user(temp_token: str):
95
- access_token = TEMP_STORE.get(temp_token)
96
- if not access_token:
97
- return JSONResponse({"error": "Invalid or expired token"}, status_code=400)
98
-
99
- user_res = requests.get(
100
- "https://huggingface.co/oauth/userinfo",
101
- headers={"Authorization": f"Bearer {access_token}"}
102
- ).json()
103
-
104
- return JSONResponse(user_res)
105
 
106
 
107
  class SessionManager:
 
12
  import time
13
  import secrets
14
  import base64
15
+ from huggingface_hub import attach_huggingface_oauth, parse_huggingface_oauth
16
  from datetime import datetime, timedelta
17
  from typing import Optional, Dict, Any
18
  import logging
 
37
  allow_headers=["*"],
38
  )
39
 
40
+ # OAuthルートを追加
41
+ attach_huggingface_oauth(app)
 
 
42
 
43
+ @app.get("/me")
44
+ def me(request: Request):
45
+ oauth_info = parse_huggingface_oauth(request)
46
+ if oauth_info is None:
47
+ return {"msg": "Not logged in"}
48
+ return {
49
+ "username": oauth_info.user_info.preferred_username,
50
+ "email": oauth_info.user_info.email,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
 
54
  class SessionManager: