Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from finetune_augmentor import AugmentationExample, AugmentationConfig, FinetuningDataAugmentor | |
| import json | |
| import streamlit.components.v1 as components | |
| from streamlit_ace import st_ace # Editable code block | |
# -------------------------------
# Page Configuration and CSS
# -------------------------------
# Page-level settings: browser-tab title, wide layout, sidebar expanded.
PAGE_SETTINGS = {
    "page_title": "Finetuning Data Augmentation Generator",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**PAGE_SETTINGS)

# HTML fragment holding the GitHub and Hugging Face logo links; it is
# rendered through a 40px-tall HTML component.
HEADER_LINKS_HTML = """
<div style="position: fixed; top: 10px; right: 10px; z-index: 100;">
<a href="https://github.com/zamalali/ftboost" target="_blank">
<img src="https://github.githubassets.com/images/modules/logos_page/GitHub-Mark.png" alt="GitHub" style="height: 30px; margin-right: 10px;">
</a>
<a href="https://huggingface.co/zamal" target="_blank">
<img src="https://huggingface.co/front/assets/huggingface_logo.svg" alt="Hugging Face" style="height: 30px;">
</a>
</div>
"""
components.html(HEADER_LINKS_HTML, height=40)
# Dark-theme stylesheet injected into the page. It has to go through
# st.markdown with unsafe_allow_html=True so the <style> tag is not escaped.
CUSTOM_CSS = """
<style>
/* Main content area */
.block-container {
background-color: #121212;
color: #ffffff;
}
/* Sidebar styling */
[data-testid="stSidebar"] {
background-color: #121212;
color: #ffffff;
}
[data-testid="stSidebar"] * {
color: #ffffff !important;
}
/* Button styling */
.stButton>button, .stDownloadButton>button {
background-color: #808080 !important;
color: #ffffff !important;
font-size: 16px;
border: none;
border-radius: 5px;
padding: 0.5rem 1.5rem;
margin-top: 1rem;
}
/* Text inputs */
.stTextInput>div>input, .stNumberInput>div>input {
border-radius: 5px;
border: 1px solid #ffffff;
padding: 0.5rem;
background-color: #1a1a1a;
color: #ffffff;
}
.stTextArea>textarea {
background-color: #1a1a1a;
color: #ffffff;
font-family: "Courier New", monospace;
border: 1px solid #ffffff;
border-radius: 5px;
padding: 1rem;
}
/* Header colors */
h1 { color: #00FF00; }
h2, h3, h4 { color: #FFFF00; }
/* Field labels */
label { color: #ffffff !important; }
/* Remove extra margin in code blocks */
pre { margin: 0; }
/* Ace editor style overrides */
.ace_editor {
border: none !important;
box-shadow: none !important;
background-color: #121212 !important;
}
/* Override alert (error/success) text colors */
[data-testid="stAlert"] { color: #ffffff !important; }
/* Add white border to expander header */
[data-testid="stExpander"] > div:first-child {
border: 1px solid #ffffff !important;
}
</style>
"""
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
# Inject JavaScript to scroll to top on load.
# NOTE(review): components.html presumably renders this inside its own
# iframe, so `window` here may be the iframe's window rather than the app
# page - confirm the scroll actually affects the main view.
SCROLL_TO_TOP_HTML = """
<script>
document.addEventListener("DOMContentLoaded", function() {
setTimeout(function() { window.scrollTo(0, 0); }, 100);
});
</script>
"""
components.html(SCROLL_TO_TOP_HTML, height=0)
# -------------------------------
# App Title and Description
# -------------------------------
# NOTE(review): the trailing "π" in the title and at the end of the
# description looks like a mis-decoded emoji - confirm the intended character.
APP_TITLE = "ftBoost π"
APP_DESCRIPTION = """
**ftBoost Hero** is a powerful tool designed to help you generate high-quality fine-tuning data for AI models.
Whether you're working with OpenAI, Gemini, Mistral, or LLaMA models, this app allows you to create structured
input-output pairs and apply augmentation techniques to enhance dataset quality. With advanced tuning parameters,
semantic similarity controls, and fluency optimization, **ftBoost Hero** ensures that your fine-tuning data is diverse,
well-structured, and ready for training. π
"""
st.title(APP_TITLE)
st.markdown(APP_DESCRIPTION, unsafe_allow_html=True)
# -------------------------------
# Step A: File Upload & Auto-Detection
# -------------------------------
# Errors raised when a parsed record does not have the shape of any known
# schema (e.g. "messages" is not a list of dicts, or "content" is not a str).
_RECORD_SHAPE_ERRORS = (AttributeError, IndexError, KeyError, TypeError)


def _detect_model(record):
    """Return the model-family label matching *record*'s layout, or None.

    A "messages" record with at least 3 entries whose first role is "system"
    is treated as OpenAI, one with exactly 2 entries as Mistral; a record
    with a "contents" key is treated as Gemini.
    """
    if "messages" in record:
        msgs = record["messages"]
        if len(msgs) >= 3 and msgs[0].get("role") == "system":
            return "OpenAI Models"
        if len(msgs) == 2:
            return "Mistral Models"
        return None
    if "contents" in record:
        return "Gemini Models"
    return None


def _extract_pair(record):
    """Return ``(input_text, output_text)`` from *record*.

    Either element is an empty string when the record does not carry it.
    May raise one of ``_RECORD_SHAPE_ERRORS`` for records of unexpected shape.
    """
    input_text, output_text = "", ""
    if "messages" in record:
        msgs = record["messages"]
        if len(msgs) >= 3:
            # Entry 0 is skipped: it is taken to be the system message.
            input_text = msgs[1].get("content", "").strip()
            output_text = msgs[2].get("content", "").strip()
        elif len(msgs) == 2:
            input_text = msgs[0].get("content", "").strip()
            output_text = msgs[1].get("content", "").strip()
    elif "contents" in record:
        contents = record["contents"]
        if len(contents) >= 2 and "parts" in contents[0] and "parts" in contents[1]:
            input_text = contents[0]["parts"][0].get("text", "").strip()
            output_text = contents[1]["parts"][0].get("text", "").strip()
    return input_text, output_text


st.markdown("##### Step 1: Upload Your Finetuning data JSONL File if you have one already (Optional)")
uploaded_file = st.file_uploader("Upload your train.jsonl file", type=["jsonl", "txt"])
uploaded_examples = []
detected_model = None
if uploaded_file is not None:
    try:
        file_content = uploaded_file.getvalue().decode("utf-8")

        # Parse line by line. A malformed line is counted and skipped instead
        # of aborting the whole upload, so one bad row can neither discard the
        # valid ones nor leave a half-filled example list behind.
        records = []
        skipped_lines = 0
        for line in file_content.splitlines():
            if not line.strip():
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                skipped_lines += 1
                continue
            if isinstance(record, dict):
                records.append(record)
            else:
                skipped_lines += 1

        # Auto-detect model type from the first successfully parsed record.
        if records:
            try:
                detected_model = _detect_model(records[0])
            except _RECORD_SHAPE_ERRORS:
                detected_model = None

        # Display an info message based on detection result
        if detected_model is not None:
            st.info(f"This JSONL file format supports the **{detected_model}**.")
        else:
            st.info("The uploaded JSONL file format is not recognized. Please manually select the appropriate model.")

        # Process every parsed record for valid examples
        for record in records:
            try:
                input_text, output_text = _extract_pair(record)
            except _RECORD_SHAPE_ERRORS:
                skipped_lines += 1
                continue
            if input_text and output_text:
                uploaded_examples.append(AugmentationExample(input_text=input_text, output_text=output_text))

        if skipped_lines:
            st.warning(f"Skipped {skipped_lines} malformed line(s) in the uploaded file.")
        if len(uploaded_examples) < 3:
            st.error("Uploaded file does not contain at least 3 valid input/output pairs.")
        else:
            st.success(f"Uploaded file processed: {len(uploaded_examples)} valid input/output pairs loaded.")
    except Exception as e:
        # UI boundary: anything unexpected (e.g. a file that is not UTF-8) is
        # reported to the user instead of surfacing as a traceback.
        st.error(f"Error processing uploaded file: {e}")
# -------------------------------
# Step B: Model Selection
# -------------------------------
model_options = ["OpenAI Models", "Gemini Models", "Mistral Models", "Llama Models"]

# Pre-select the format detected from the upload, falling back to OpenAI.
if detected_model is None:
    default_model = "OpenAI Models"
else:
    default_model = detected_model

# An unknown label falls back to the first option.
try:
    default_index = model_options.index(default_model)
except ValueError:
    default_index = 0

model_type = st.selectbox(
    "Select the output format for finetuning",
    model_options,
    index=default_index,
)
# -------------------------------
# Step C: System Message & API Key
# -------------------------------
system_message = st.text_input(
    "System Message (optional) only for OpenAI models",
    value="Marv is a factual chatbot that is also sarcastic.",
)

# The key is entered through a password-type field.
GROQ_KEY_LABEL = "LangChain Groq API Key (if you don't have one, get it from [here](https://console.groq.com/keys))"
GROQ_KEY_HELP = "Enter your LangChain Groq API Key for data augmentation"
groq_api_key = st.text_input(GROQ_KEY_LABEL, type="password", help=GROQ_KEY_HELP)
# -------------------------------
# Step D: Input Schema Template Display
# -------------------------------
# Example record for the formats that have their own layout.
SCHEMA_EXAMPLES = {
    "OpenAI Models": '''{
"messages": [
{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."},
{"role": "user", "content": "What's the capital of France?"},
{"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}
]
}''',
    "Gemini Models": '''{
"contents": [
{"role": "user", "parts": [{"text": "What's the capital of France?"}]},
{"role": "model", "parts": [{"text": "Paris, as if everyone doesn't know that already."}]}
]
}''',
}
# Every other selection is shown the plain user/assistant pair.
DEFAULT_SCHEMA_EXAMPLE = '''{
"messages": [
{"role": "user", "content": "What's the capital of France?"},
{"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}
]
}'''

st.markdown("#### Input Schema Template")
st.code(SCHEMA_EXAMPLES.get(model_type, DEFAULT_SCHEMA_EXAMPLE), language="json")
# -------------------------------
# Step E: Manual Input of Pairs (if no file uploaded)
# -------------------------------
# Defined unconditionally so the generate step can always iterate it, even
# when a file was uploaded and no editors are rendered.
pair_templates = []
if uploaded_file is None:
    st.markdown("##### Enter at least 3 input/output pairs manually:")
    num_pairs = st.number_input("Number of Pairs", min_value=3, value=3, step=1)

    # json.dumps quotes and escapes the system message, so quotes, backslashes
    # or newlines typed by the user cannot produce an invalid JSON template.
    system_content = json.dumps(system_message, ensure_ascii=False)

    for i in range(int(num_pairs)):
        st.markdown(f"##### Pair {i+1}")
        if model_type == "OpenAI Models":
            init_template = ('''{
"messages": [
{"role": "system", "content": ''' + system_content + '''},
{"role": "user", "content": "Enter your input text here"},
{"role": "assistant", "content": "Enter your output text here"}
]
}''').strip()
            # The system message is part of the widget key, so changing it
            # gives the editor a new key (and with it the new template).
            ace_key = f"pair_{i}_{model_type}_{system_message}"
        elif model_type == "Gemini Models":
            init_template = ('''{
"contents": [
{"role": "user", "parts": [{"text": "Enter your input text here"}]},
{"role": "model", "parts": [{"text": "Enter your output text here"}]}
]
}''').strip()
            ace_key = f"pair_{i}_{model_type}"
        else:
            init_template = ('''{
"messages": [
{"role": "user", "content": "Enter your input text here"},
{"role": "assistant", "content": "Enter your output text here"}
]
}''').strip()
            ace_key = f"pair_{i}_{model_type}"
        pair = st_ace(
            placeholder="Edit your code here...",
            value=init_template,
            language="json",
            theme="monokai",
            key=ace_key,
            height=150,
        )
        pair_templates.append(pair)
# -------------------------------
# Step F: Augmentation Settings
# -------------------------------
def _unit_slider(label, default):
    """Render a slider over 0.0-1.0 in steps of 0.01, starting at *default*."""
    return st.slider(label, min_value=0.0, max_value=1.0, value=default, step=0.01)


target_augmented = st.number_input(
    "Number of Augmented Pairs to Generate",
    min_value=5,
    value=5,
    step=1,
)

# The goal is fixed in code and only displayed to the user.
finetuning_goal = "Improve conversational clarity and capture subtle nuances"
st.markdown("**Finetuning Goal:** " + finetuning_goal)

# Quality thresholds that are later passed to the augmentation config.
with st.expander("Show/Hide Advanced Tuning Parameters"):
    min_semantic = _unit_slider("Minimum Semantic Similarity", 0.80)
    max_semantic = _unit_slider("Maximum Semantic Similarity", 0.95)
    min_diversity = _unit_slider("Minimum Diversity Score", 0.70)
    min_fluency = _unit_slider("Minimum Fluency Score", 0.80)
# -------------------------------
# Step G: Generate Data Button and Pipeline Execution
# -------------------------------
# Selectbox label (lower-cased) -> format_type passed to get_formatted_output.
_OUTPUT_FORMATS = {
    "openai models": "openai",
    "gemini models": "gemini",
    "mistral models": "mistral",
    "llama models": "llama",
}


def _parse_manual_pair(pair, model_type):
    """Return ``(input_text, output_text)`` from one edited JSON template.

    *pair* is the editor text and *model_type* the selected format label.
    Raises ValueError when the JSON is invalid, the number of messages does
    not match the format, or either text is empty. Other lookup errors can
    propagate for records of unexpected shape.
    """
    record = json.loads(pair)
    if model_type == "OpenAI Models":
        msgs = record.get("messages", [])
        if len(msgs) != 3:
            raise ValueError("Expected 3 messages")
        input_text = msgs[1].get("content", "").strip()
        output_text = msgs[2].get("content", "").strip()
    elif model_type == "Gemini Models":
        contents = record.get("contents", [])
        if len(contents) < 2:
            raise ValueError("Expected at least 2 contents")
        input_text = contents[0]["parts"][0].get("text", "").strip()
        output_text = contents[1]["parts"][0].get("text", "").strip()
    else:
        msgs = record.get("messages", [])
        if len(msgs) != 2:
            raise ValueError("Expected 2 messages for this format")
        input_text = msgs[0].get("content", "").strip()
        output_text = msgs[1].get("content", "").strip()
    if not input_text or not output_text:
        raise ValueError("Input or output text is empty")
    return input_text, output_text


if st.button("Generate Data"):
    if not groq_api_key.strip():
        st.error("Please enter your LangChain Groq API Key to proceed.")
        st.stop()

    # Choose examples: from uploaded file if available; otherwise from manual input.
    if uploaded_file is not None:
        if len(uploaded_examples) >= 3:
            examples = uploaded_examples
        else:
            # The manual editors are not rendered while a file is uploaded,
            # so there is nothing to fall back to; report it instead of
            # touching the (undefined) manual pair list.
            examples = []
            st.error(
                "Cannot generate data: the uploaded file has fewer than 3 valid "
                "input/output pairs. Upload a valid file or remove it to enter pairs manually."
            )
    else:
        examples = []
        errors = []
        for idx, pair in enumerate(pair_templates):
            try:
                input_text, output_text = _parse_manual_pair(pair, model_type)
                examples.append(AugmentationExample(input_text=input_text, output_text=output_text))
            except Exception as e:
                # Broad on purpose: every per-pair failure is collected and
                # reported together rather than stopping at the first one.
                errors.append(f"Error in pair {idx+1}: {e}")
        if errors:
            st.error("There were errors in your input pairs:\n" + "\n".join(errors))
        elif len(examples) < 3:
            st.error("Please provide at least 3 valid pairs.")

    # As before, generation proceeds whenever at least 3 valid pairs exist,
    # even if some other manual pairs were reported as invalid above.
    if len(examples) >= 3:
        # NOTE(review): hard-coded Groq model id - confirm it is still served
        # by the provider before relying on it.
        target_model = "mixtral-8x7b-32768"
        try:
            config = AugmentationConfig(
                target_model=target_model,
                examples=examples,
                finetuning_goal=finetuning_goal,
                groq_api_key=groq_api_key,
                system_message=system_message,
                min_semantic_similarity=min_semantic,
                max_semantic_similarity=max_semantic,
                min_diversity_score=min_diversity,
                min_fluency_score=min_fluency,
            )
        except Exception as e:
            st.error(f"Configuration error: {e}")
            st.stop()

        st.markdown('<p style="color: white;">Running augmentation pipeline... Please wait.</p>', unsafe_allow_html=True)
        augmentor = FinetuningDataAugmentor(config)
        augmentor.run_augmentation(target_count=target_augmented)

        # Unknown labels fall back to the OpenAI layout.
        format_type = _OUTPUT_FORMATS.get(model_type.lower(), "openai")
        output_data = augmentor.get_formatted_output(format_type=format_type)

        st.markdown("### Augmented Data")
        st.code(output_data, language="json")
        st.download_button("Download train.jsonl", output_data, file_name="train.jsonl")