Spaces:
Sleeping
Sleeping
| import duckdb | |
| import gradio as gr | |
| import polars as pl | |
| from datasets import load_dataset | |
| from gradio_huggingfacehub_search import HuggingfaceHubSearch | |
| from model2vec import StaticModel | |
# Shared state for the Gradio callbacks below. A module-level ``global``
# statement binds nothing, so the names are initialised explicitly instead.
# ``ds`` is assigned by load_dataset_from_hub; ``df`` by vectorize_dataset.
ds = None
df = None

# Load a model from the HuggingFace hub (in this case the potion-base-8M model)
model_name = "minishlab/potion-base-8M"
model = StaticModel.from_pretrained(model_name)
def get_iframe(hub_repo_id):
    """Build the HTML snippet that embeds the Hub dataset viewer for a repo.

    Args:
        hub_repo_id: Dataset repository id, e.g. ``"user/dataset"``.

    Returns:
        An ``<iframe>`` HTML string whose ``src`` is the dataset's embedded
        viewer URL.

    Raises:
        ValueError: If ``hub_repo_id`` is empty or ``None``.
    """
    # Function-scope imports keep this block self-contained.
    from html import escape
    from urllib.parse import quote

    if not hub_repo_id:
        raise ValueError("Hub repo id is required")

    # The id is interpolated into raw HTML, so percent-encode everything
    # except the owner/name separator; well-formed ids pass through unchanged.
    safe_repo_id = quote(hub_repo_id, safe="/")
    url = f"https://huggingface.co/datasets/{safe_repo_id}/embed/viewer"
    iframe = f"""
    <iframe
        src="{escape(url, quote=True)}"
        frameborder="0"
        width="100%"
        height="600px"
    ></iframe>
    """
    return iframe
def load_dataset_from_hub(hub_repo_id):
    """Load ``hub_repo_id`` with ``load_dataset`` and store it in the global ``ds``."""
    global ds
    loaded = load_dataset(hub_repo_id)
    ds = loaded
def get_columns(split: str):
    """Return a column dropdown for ``split`` of the loaded dataset.

    The choices are the split's column names and the first column is
    preselected.
    """
    column_names = ds[split].column_names
    first_column = column_names[0]
    return gr.Dropdown(
        choices=column_names,
        value=first_column,
        label="Select a column",
    )
def get_splits():
    """Return a split dropdown listing the keys of the loaded dataset.

    The first split is preselected.
    """
    split_names = [*ds.keys()]
    default_split = split_names[0]
    return gr.Dropdown(
        choices=split_names,
        value=default_split,
        label="Select a split",
    )
def vectorize_dataset(split: str, column: str):
    """Embed one column of a split and publish the result as the global ``df``.

    Args:
        split: Key of the split inside the global ``ds``.
        column: Name of the column whose values are passed to ``model.encode``.

    The frame is built in a local variable and assigned to ``df`` only once
    the ``embeddings`` column is attached. If encoding raises, ``df`` keeps
    its previous value rather than a frame that lacks the column queried by
    run_query.
    """
    global df

    frame = ds[split].to_polars()
    embeddings = model.encode(frame[column], max_length=512 * 4)
    df = frame.with_columns(pl.Series(embeddings).alias("embeddings"))
def run_query(query: str):
    """Return the five rows of ``df`` closest to ``query`` by cosine distance.

    Args:
        query: Free-text query; it is embedded with the same ``model`` that
            produced the ``embeddings`` column.

    Returns:
        The result of ``.to_df()`` on the DuckDB relation: the five rows with
        the smallest ``array_cosine_distance`` to the query vector.
    """
    # ``df`` is not used as a Python expression here; DuckDB resolves the
    # ``FROM df`` table name against the Python variable of the same name.
    global df

    vector = model.encode(query)
    # Take the array width from the query vector instead of hard-coding 256,
    # so switching to a model with another embedding size keeps the cast valid.
    dim = len(vector)
    return duckdb.sql(
        query=f"""
        SELECT *
        FROM df
        ORDER BY array_cosine_distance(embeddings, {vector.tolist()}::FLOAT[{dim}])
        LIMIT 5
        """
    ).to_df()
# UI layout and event wiring for the vector-search demo.
with gr.Blocks() as demo:
    gr.HTML(
        """
        <h1>Vector Search any Hugging Face Dataset</h1>
        <p>
        This app allows you to vector search any Hugging Face dataset.
        You can search for the nearest neighbors of a query vector, or
        perform a similarity search on a dataframe.
        </p>
        """
    )
    with gr.Row():
        with gr.Column():
            search_in = HuggingfaceHubSearch(
                label="Search Huggingface Hub",
                # The search is restricted to datasets, so say so.
                placeholder="Search for datasets on Huggingface",
                search_type="dataset",
                # NOTE(review): "sumbit" appears to be the component's own
                # keyword spelling - confirm against the installed version
                # before renaming it.
                sumbit_on_select=True,
            )
    with gr.Row():
        search_out = gr.HTML(label="Search Results")
    with gr.Row(variant="panel"):
        split_dropdown = gr.Dropdown(label="Select a split")
        column_dropdown = gr.Dropdown(label="Select a column")
    with gr.Row(variant="panel"):
        query_input = gr.Textbox(label="Query")

    # Picking a dataset: show its viewer, load it, then fill both dropdowns.
    search_in.submit(get_iframe, inputs=search_in, outputs=search_out).then(
        fn=load_dataset_from_hub,
        inputs=search_in,
        show_progress=True,
    ).then(fn=get_splits, outputs=split_dropdown).then(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    )

    # A new split refreshes the column list and re-embeds the selection.
    split_dropdown.change(
        fn=get_columns, inputs=split_dropdown, outputs=column_dropdown
    ).then(fn=vectorize_dataset, inputs=[split_dropdown, column_dropdown])

    # A column picked by the user must also re-embed; otherwise queries keep
    # searching the previous column's vectors. ``input`` is used instead of
    # ``change`` so the programmatic update above does not embed twice.
    column_dropdown.input(
        fn=vectorize_dataset, inputs=[split_dropdown, column_dropdown]
    )

    btn_run = gr.Button("Run")
    results_output = gr.Dataframe(label="Results")
    btn_run.click(fn=run_query, inputs=query_input, outputs=results_output)

demo.launch()