Spaces:
Runtime error
Runtime error
| import gzip | |
| import json | |
| import numpy as np | |
| import pandas as pd | |
| import streamlit as st | |
| import torch | |
| import tqdm | |
| from sentence_transformers import SentenceTransformer | |
def load_model(model_name, model_dict):
    """Instantiate the SentenceTransformer model(s) registered under *model_name*.

    Args:
        model_name: Key to look up in *model_dict*.
        model_dict: Mapping from a name to either a single model id (``str``)
            or an iterable of model ids.

    Returns:
        A single ``SentenceTransformer`` when the entry is a string, otherwise
        a list holding one ``SentenceTransformer`` per id, in iteration order.

    Raises:
        KeyError: If *model_name* is not a key of *model_dict*.
        TypeError: If the entry is neither a string nor an iterable.
    """
    # Explicit check instead of ``assert``: asserts are stripped under -O.
    if model_name not in model_dict:
        raise KeyError(f"Unknown model name: {model_name!r}")

    # Only the requested entry is instantiated, not every entry of model_dict.
    model_ids = model_dict[model_name]

    # ``str`` is itself iterable, so it must be tested before the generic
    # iterable case; isinstance also accepts str subclasses.
    if isinstance(model_ids, str):
        return SentenceTransformer(model_ids)
    if hasattr(model_ids, "__iter__"):
        return [SentenceTransformer(name) for name in model_ids]
    raise TypeError(
        f"Model entry for {model_name!r} must be a str or an iterable of str, "
        f"got {type(model_ids).__name__}"
    )
def load_embeddings(*, path='./data/stackoverflow-titles-distilbert-emb.csv', max_rows=10000):
    """Load pre-generated corpus embeddings as a float32 tensor.

    Args:
        path: Text file of embedding values readable by ``np.loadtxt``.
        max_rows: Maximum number of rows to read from *path*.

    Returns:
        A ``torch.Tensor`` of dtype ``float32`` holding the loaded values.
    """
    # NOTE(review): np.loadtxt is called without ``delimiter``, so values must
    # be whitespace-separated despite the .csv extension -- confirm against
    # the data file.
    corpus_emb = torch.from_numpy(np.loadtxt(path, max_rows=max_rows))
    # loadtxt yields float64; cast down to float32.
    return corpus_emb.float()
def filter_questions(
    tag,
    max_questions=10000,
    *,
    path="./data/stackoverflow-titles.jsonl.gz",
    max_posts=6_000_000,
):
    """Return up to *max_questions* posts whose ``"tags"`` contain *tag*.

    Posts are read as one JSON object per line from a gzip-compressed JSONL
    file; only the first *max_posts* lines are considered. Matching posts are
    returned in file order.

    Args:
        tag: Value tested with ``tag in post["tags"]`` (an exact element match
            if ``tags`` is a list, a substring match if it is a string --
            presumably a list; confirm against the data file).
        max_questions: Maximum number of matching posts to return. Values
            ``<= 0`` yield an empty list.
        path: Location of the gzip-compressed JSONL file.
        max_posts: Maximum number of lines to read from *path*.

    Returns:
        A list of the matching post dicts, at most *max_questions* long.
    """
    # Imported explicitly: a bare ``import tqdm`` does not load the
    # ``tqdm.auto`` submodule on its own.
    from tqdm.auto import tqdm as progress_bar

    matches = []
    # Checked up front; otherwise the first match would be appended before
    # the length test and a limit of 0 would still return one post.
    if max_questions <= 0:
        return matches

    # Filter while streaming instead of first holding up to max_posts parsed
    # posts in memory, and stop reading as soon as enough matches are found.
    with gzip.open(path, "rt", encoding="utf-8") as f_in:
        with progress_bar(f_in, total=max_posts, desc="Load data") as lines:
            for n_read, line in enumerate(lines, start=1):
                post = json.loads(line)
                if tag in post["tags"]:
                    matches.append(post)
                    if len(matches) >= max_questions:
                        break
                if n_read >= max_posts:
                    break
    return matches
def load_gender_data(*, random_state=None):
    """Return one sentence triple sampled from the gender-bias dataset.

    Args:
        random_state: Forwarded to ``DataFrame.sample``. Pass an int for a
            reproducible draw; the default ``None`` samples non-deterministically.

    Returns:
        A ``(base_sentence, male_sentence, female_sentence)`` tuple taken from
        a single row of ``load_gendered_dataset()``.

    Raises:
        ValueError: If the dataset has no rows.
    """
    df = load_gendered_dataset()
    if df.empty:
        raise ValueError("The gender bias evaluation dataset is empty; nothing to sample.")
    sampled_row = df.sample(n=1, random_state=random_state).iloc[0]
    return sampled_row.base_sentence, sampled_row.male_sentence, sampled_row.female_sentence
def load_gendered_dataset(*, path='./data/bias_evaluation.csv'):
    """Read the gender-bias evaluation CSV into a DataFrame.

    ``load_gender_data`` in this module reads the ``base_sentence``,
    ``male_sentence`` and ``female_sentence`` columns of the result.

    Args:
        path: Location of the CSV file.

    Returns:
        The ``pandas.DataFrame`` produced by ``pd.read_csv(path)``.
    """
    return pd.read_csv(path)