# NOTE: this file was captured from a "Spaces" web page; the status shown on
# the page at capture time was "Runtime error".
import base64
from io import BytesIO

from bokeh.io import push_notebook, output_notebook
# output_notebook()
from bokeh.models import ColumnDataSource, Grid, LinearAxis, Plot, Scatter
from bokeh.palettes import d3
from bokeh.plotting import figure, show
from bokeh.transform import factor_cmap, factor_mark
from datasets import load_dataset
import streamlit as st

from data_utils import get_embedding

# Dataset columns offered in the "Color by" selector; also the columns for
# which cache_embedding() collects the unique label values.
label_columns = ["gender", "subCategory", "masterCategory"]

# Model identifiers offered in the sidebar "Model" selector; the chosen one is
# passed through to get_embedding().
model_interest = [
    "facebook/deit-tiny-patch16-224",  # very small model 5M param model
    "microsoft/beit-base-patch16-224",  # big model
    "facebook/dino-vits8",
    "facebook/levit-128S",
]
def convert_base64(img):
    """Encode an image as a base64 JPEG data URI.

    Args:
        img: Object with a PIL-style ``save(fp, format=...)`` method. When it
            also exposes ``mode`` and ``convert`` (as PIL images do), images
            in a mode the JPEG writer cannot store are converted to RGB first.

    Returns:
        A string of the form ``"data:image/jpeg;base64,<payload>"``.
    """
    # Modes Pillow's JPEG writer accepts; anything else (RGBA, P, LA, ...)
    # makes ``save(..., format="JPEG")`` raise, so convert those to RGB.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    mode = getattr(img, "mode", None)
    if mode is not None and mode not in jpeg_modes:
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return "data:image/jpeg;base64," + img_str
def cache_embedding(model_name):
    """Compute embeddings and label values for the visualisation subset.

    Loads the ``train`` split of ``ceyda/fashion-products-small``, shuffles it
    with a fixed seed, and keeps the 10% "test" part of a non-shuffled split
    as the subset to visualise.

    Args:
        model_name: Model identifier forwarded to ``get_embedding``.

    Returns:
        A tuple ``(embedding, labels)``:
        - ``embedding``: the result of ``get_embedding`` with its ``"image"``
          column replaced by base64 JPEG data URIs (presumably a pandas
          DataFrame, given the column ``.apply`` -- confirm in data_utils).
        - ``labels``: dict mapping each name in ``label_columns`` to the
          unique values of that column in the subset.
    """
    dataset = load_dataset("ceyda/fashion-products-small", split="train")
    dataset = dataset.shuffle(seed=100)  # fixed seed, so the same subset every run
    # Take a portion out for visualisation; no second shuffle is needed because
    # the dataset was already shuffled above.
    viz_dat = dataset.train_test_split(0.1, shuffle=False)["test"]

    embedding = get_embedding(model_name, viz_dat)
    # Inline the images as data URIs so the plot tooltips can show them.
    embedding["image"] = embedding["image"].apply(convert_base64)

    labels = {label: viz_dat.unique(label) for label in label_columns}
    return embedding, labels
def cache_graph(model_name, color_column):
    """Build the Bokeh scatter plot of embeddings, coloured by one label column.

    Args:
        model_name: Model identifier forwarded to ``cache_embedding``.
        color_column: Name of the column whose values pick each point's
            colour; must be a key of the labels dict returned by
            ``cache_embedding`` (i.e. one of ``label_columns``).

    Returns:
        The Bokeh figure containing the scatter glyph. The data source is
        expected to provide ``"x"``, ``"y"``, ``"image"`` and ``color_column``.
    """
    embedding, labels = cache_embedding(model_name)
    factors = labels[color_column]

    # Three 20-colour d3 palettes are concatenated (60 colours at most) and
    # trimmed to one colour per category value.
    all_colors = d3["Category20"][20] + d3["Category20b"][20] + d3["Category20c"][20]
    color_palette = all_colors[: len(factors)]

    source = ColumnDataSource(data=embedding)

    TOOLS = "hover,crosshair,pan,wheel_zoom,zoom_in,zoom_out,box_zoom,reset,tap,save,box_select,lasso_select,"
    # Hover tooltip showing the point's image; "@image" is filled from the
    # data source column holding the base64 data URI.
    TOOLTIPS = """
    <div>
        <div>
            <img
                src="@image" height="42" alt="@image" width="42"
                style="float: left; margin: 0px 15px 15px 0px;"
                border="2"
            ></img>
        </div>
    </div>
    """

    p = figure(tools=TOOLS, tooltips=TOOLTIPS)
    p.scatter(
        x="x",
        y="y",
        source=source,
        color=factor_cmap(color_column, color_palette, factors),
    )
    return p
# --- Streamlit page -------------------------------------------------------
# Building the plot loads the dataset and computes embeddings, so warn the
# user before the slow part starts.
st.write("It takes some time for the graph to load...wait please")

# User choices: the embedding model (sidebar) and the label column used for
# colouring (main page).
model_name = st.sidebar.selectbox("Model", model_interest)
color_column = st.selectbox("Color by", label_columns)

# Build the figure for the current selection and render it.
p = cache_graph(model_name, color_column)
st.bokeh_chart(p, use_container_width=False)