Commit
·
d9ea03e
1
Parent(s):
e7c8363
Update app.py
Browse files
app.py
CHANGED
|
@@ -23,6 +23,25 @@ def get_user_models(hf_username, env_tag, lib_tag):
|
|
| 23 |
return user_model_ids
|
| 24 |
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
def get_metadata(model_id):
|
| 27 |
"""
|
| 28 |
Get model metadata (contains evaluation data)
|
|
@@ -208,8 +227,8 @@ def certification(hf_username):
|
|
| 208 |
},
|
| 209 |
{
|
| 210 |
"unit": "Unit 8 PII",
|
| 211 |
-
"env": "
|
| 212 |
-
"library": "
|
| 213 |
"min_result": 100,
|
| 214 |
"best_result": 0,
|
| 215 |
"best_model_id": "",
|
|
@@ -217,8 +236,13 @@ def certification(hf_username):
|
|
| 217 |
},
|
| 218 |
]
|
| 219 |
for unit in results_certification:
|
| 220 |
-
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 222 |
|
| 223 |
# Calculate the best result and get the best_model_id
|
| 224 |
best_result, best_model_id = calculate_best_result(user_models)
|
|
|
|
| 23 |
return user_model_ids
|
| 24 |
|
| 25 |
|
| 26 |
+
def get_user_sf_models(hf_username, env_tag, lib_tag):
|
| 27 |
+
api = HfApi()
|
| 28 |
+
models_sf = []
|
| 29 |
+
models = api.list_models(author=hf_username, filter=["reinforcement-learning", lib_tag])
|
| 30 |
+
|
| 31 |
+
user_model_ids = [x.modelId for x in models]
|
| 32 |
+
|
| 33 |
+
for model in user_model_ids:
|
| 34 |
+
meta = get_metadata(model)
|
| 35 |
+
if meta is None:
|
| 36 |
+
continue
|
| 37 |
+
result = meta["model-index"][0]["results"][0]["dataset"]["name"]
|
| 38 |
+
if result == env_tag:
|
| 39 |
+
models_sf.append(model)
|
| 40 |
+
|
| 41 |
+
user_sf_models_ids = [x.modelId for x in models_sf]
|
| 42 |
+
return user_sf_models_ids
|
| 43 |
+
|
| 44 |
+
|
| 45 |
def get_metadata(model_id):
|
| 46 |
"""
|
| 47 |
Get model metadata (contains evaluation data)
|
|
|
|
| 227 |
},
|
| 228 |
{
|
| 229 |
"unit": "Unit 8 PII",
|
| 230 |
+
"env": "doom_health_gathering_supreme",
|
| 231 |
+
"library": "sample-factory",
|
| 232 |
"min_result": 100,
|
| 233 |
"best_result": 0,
|
| 234 |
"best_model_id": "",
|
|
|
|
| 236 |
},
|
| 237 |
]
|
| 238 |
for unit in results_certification:
|
| 239 |
+
if unit["unit"] != "Unit 8 PII":
|
| 240 |
+
# Get user model
|
| 241 |
+
user_models = get_user_models(hf_username, unit['env'], unit['library'])
|
| 242 |
+
# For sample factory vizdoom we don't have env tag for now
|
| 243 |
+
else:
|
| 244 |
+
user_models = get_user_sf_models(hf_username, unit['env'], unit['library'])
|
| 245 |
+
|
| 246 |
|
| 247 |
# Calculate the best result and get the best_model_id
|
| 248 |
best_result, best_model_id = calculate_best_result(user_models)
|