Vokturz committed
Commit · fddae32
1 Parent(s): 2d9aa2d

cache default model

Changed files: src/app.py (+11 -11)
src/app.py CHANGED
@@ -22,6 +22,10 @@ st.markdown(
 def get_gpu_specs():
     return pd.read_csv("data/gpu_specs.csv")
 
+@st.cache_resource
+def get_mistralai_table():
+    model = get_model("mistralai/Mistral-7B-v0.1", library="transformers", access_token="")
+    return calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"])
 
 def show_gpu_info(info, trainable_params=0):
     for var in ['Inference', 'Full Training Adam', 'LoRa Fine-tuning']:
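For reference, st.cache_resource memoizes the decorated function's return value across Streamlit reruns and sessions, so the default model's memory table above is computed once per process instead of once per visitor. A minimal sketch of the same pattern, with a hypothetical expensive_table() standing in for the app's get_model + calculate_memory pipeline:

import time
import streamlit as st

@st.cache_resource  # cached once per process, shared across reruns and sessions
def expensive_table(model_name: str):
    # hypothetical stand-in for get_model(...) followed by calculate_memory(...)
    time.sleep(5)  # simulate the slow config download / computation
    return {"model": model_name, "dtypes": ["float32", "int8"]}

table = expensive_table("mistralai/Mistral-7B-v0.1")  # slow only on the first call
st.json(table)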
@@ -46,13 +50,6 @@ def get_name(index):
     row = gpu_specs.iloc[index]
     return f"{row['Product Name']} ({row['RAM (GB)']} GB, {row['Year']})"
 
-def create_plot(memory_table, y, title, container):
-    fig = px.bar(memory_table, x=memory_table.index, y=y, color_continuous_scale="RdBu_r")
-    fig.update_layout(yaxis_title="Number of GPUs", title=dict(text=title, font=dict(size=25)))
-    fig.update_coloraxes(showscale=False)
-
-    container.plotly_chart(fig, use_container_width=True)
-
 gpu_specs = get_gpu_specs()
 
 access_token = st.sidebar.text_input("Access token")
@@ -61,16 +58,19 @@ if not model_name:
     st.info("Please enter a model name")
     st.stop()
 
-
-
 model_name = extract_from_url(model_name)
 if model_name not in st.session_state:
     if 'actual_model' in st.session_state:
         del st.session_state[st.session_state['actual_model']]
         del st.session_state['actual_model']
         gc.collect()
-    model = get_model(model_name, library="transformers", access_token=access_token)
-    st.session_state[model_name] = calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"])
+    if model_name == "mistralai/Mistral-7B-v0.1": # cache Mistral
+        st.session_state[model_name] = get_mistralai_table()
+    else:
+        model = get_model(model_name, library="transformers", access_token=access_token)
+        st.session_state[model_name] = calculate_memory(model, ["float32", "float16/bfloat16", "int8", "int4"])
+        del model
+        gc.collect()
     st.session_state['actual_model'] = model_name
 
 
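The last hunk keeps at most one model's table per session: the previous entry in st.session_state is deleted (followed by gc.collect()) before a new one is stored, and the default Mistral model now takes the process-wide cached path instead of being recomputed for every session. A short sketch of that evict-then-store pattern, with a hypothetical compute_table() in place of the real pipeline:

import gc
import streamlit as st

def compute_table(model_name: str):
    # hypothetical stand-in for get_model(...) + calculate_memory(...)
    return {"model": model_name}

def remember(model_name: str):
    # Keep only the current model's table in this session's state.
    if model_name not in st.session_state:
        if 'actual_model' in st.session_state:
            # evict the previously selected model's table
            del st.session_state[st.session_state['actual_model']]
            del st.session_state['actual_model']
            gc.collect()  # release the old table promptly
        st.session_state[model_name] = compute_table(model_name)
        st.session_state['actual_model'] = model_name
    return st.session_state[model_name]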