Spaces:
Running
Running
| """ | |
| Main Flask application for the watermark detection web interface. | |
| """ | |
| from flask import Flask, render_template, request, jsonify, Response, stream_with_context | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| import json | |
| from ..core.detector import MarylandDetector, MarylandDetectorZ, OpenaiDetector, OpenaiDetectorZ | |
| from ..core.generator import WmGenerator, OpenaiGenerator, MarylandGenerator | |
| from .utils import get_token_details, template_prompt | |
| CACHE_DIR = "wm_interactive/static/hf_cache" | |
| def convert_nan_to_null(obj): | |
| """Convert NaN values to null for JSON serialization""" | |
| import math | |
| if isinstance(obj, float) and math.isnan(obj): | |
| return None | |
| elif isinstance(obj, dict): | |
| return {k: convert_nan_to_null(v) for k, v in obj.items()} | |
| elif isinstance(obj, list): | |
| return [convert_nan_to_null(item) for item in obj] | |
| return obj | |
| def set_to_int(value, default_value = None): | |
| try: | |
| return int(value) | |
| except (ValueError, TypeError): | |
| return default_value | |
| def create_detector(detector_type, tokenizer, **kwargs): | |
| """Create a detector instance based on the specified type.""" | |
| detector_map = { | |
| 'maryland': MarylandDetector, | |
| 'marylandz': MarylandDetectorZ, | |
| 'openai': OpenaiDetector, | |
| 'openaiz': OpenaiDetectorZ | |
| } | |
| # Validate and set default values for parameters | |
| if 'seed' in kwargs: | |
| kwargs['seed'] = set_to_int(kwargs['seed'], default_value = 0) | |
| if 'ngram' in kwargs: | |
| kwargs['ngram'] = set_to_int(kwargs['ngram'], default_value = 1) | |
| detector_class = detector_map.get(detector_type, MarylandDetector) | |
| return detector_class(tokenizer=tokenizer, **kwargs) | |
| def create_app(): | |
| app = Flask(__name__, | |
| static_folder='../static', | |
| template_folder='../templates') | |
| # Add zip to Jinja's global context | |
| app.jinja_env.globals.update(zip=zip) | |
| # Pick a model | |
| # model_id = "meta-llama/Llama-3.2-1B-Instruct" | |
| model_id = "HuggingFaceTB/SmolLM2-135M-Instruct" | |
| tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=CACHE_DIR) | |
| model = AutoModelForCausalLM.from_pretrained(model_id, cache_dir=CACHE_DIR).to("cuda" if torch.cuda.is_available() else "cpu") | |
| # Create default generator | |
| generator = MarylandGenerator(model, tokenizer, ngram=1, seed=0) | |
| def index(): | |
| return render_template("index.html") | |
| def tokenize(): | |
| try: | |
| data = request.get_json() | |
| if not data: | |
| return jsonify({'error': 'No JSON data received'}), 400 | |
| text = data.get('text', '') | |
| params = data.get('params', {}) | |
| # Create a detector instance with the provided parameters | |
| detector = create_detector( | |
| detector_type=params.get('detector_type', 'maryland'), | |
| tokenizer=tokenizer, | |
| seed=params.get('seed', 0), | |
| ngram=params.get('ngram', 1) | |
| ) | |
| if text: | |
| try: | |
| display_info = get_token_details(text, detector) | |
| # Extract summary stats (last item in display_info) | |
| stats = display_info.pop() | |
| response_data = { | |
| 'token_count': len(display_info), | |
| 'tokens': [info['token'] for info in display_info], | |
| 'colors': [info['color'] for info in display_info], | |
| 'scores': [info['score'] if info.get('is_scored', False) else None for info in display_info], | |
| 'pvalues': [info['pvalue'] if info.get('is_scored', False) else None for info in display_info], | |
| 'final_score': stats.get('final_score', 0) if stats.get('final_score') is not None else 0, | |
| 'ntoks_scored': stats.get('ntoks_scored', 0) if stats.get('ntoks_scored') is not None else 0, | |
| 'final_pvalue': stats.get('final_pvalue', 0.5) if stats.get('final_pvalue') is not None else 0.5 | |
| } | |
| # Convert any NaN values to null before sending | |
| response_data = convert_nan_to_null(response_data) | |
| # Ensure numeric fields have default values if they became null | |
| if response_data['final_score'] is None: | |
| response_data['final_score'] = 0 | |
| if response_data['ntoks_scored'] is None: | |
| response_data['ntoks_scored'] = 0 | |
| if response_data['final_pvalue'] is None: | |
| response_data['final_pvalue'] = 0.5 | |
| return jsonify(response_data) | |
| except Exception as e: | |
| app.logger.error(f'Error processing text: {str(e)}') | |
| return jsonify({'error': f'Error processing text: {str(e)}'}), 500 | |
| return jsonify({ | |
| 'token_count': 0, | |
| 'tokens': [], | |
| 'colors': [], | |
| 'scores': [], | |
| 'pvalues': [], | |
| 'final_score': 0, | |
| 'ntoks_scored': 0, | |
| 'final_pvalue': 0.5 | |
| }) | |
| except Exception as e: | |
| app.logger.error(f'Server error: {str(e)}') | |
| return jsonify({'error': f'Server error: {str(e)}'}), 500 | |
| def generate(): | |
| try: | |
| data = request.get_json() | |
| if not data: | |
| return jsonify({'error': 'No JSON data received'}), 400 | |
| prompt = template_prompt(data.get('prompt', '')) | |
| params = data.get('params', {}) | |
| temperature = float(params.get('temperature', 0.8)) | |
| def generate_stream(): | |
| try: | |
| # Create generator with correct parameters | |
| generator_class = OpenaiGenerator if params.get('detector_type') == 'openai' else MarylandGenerator | |
| generator = generator_class( | |
| model=model, | |
| tokenizer=tokenizer, | |
| ngram=set_to_int(params.get('ngram', 1)), | |
| seed=set_to_int(params.get('seed', 0)), | |
| delta=float(params.get('delta', 2.0)), | |
| ) | |
| # Get special tokens to filter out | |
| special_tokens = { | |
| '<|im_start|>', '<|im_end|>', | |
| tokenizer.pad_token, tokenizer.eos_token, | |
| tokenizer.bos_token if hasattr(tokenizer, 'bos_token') else None, | |
| tokenizer.sep_token if hasattr(tokenizer, 'sep_token') else None | |
| } | |
| special_tokens = {t for t in special_tokens if t is not None} | |
| # Encode prompt | |
| prompt_tokens = tokenizer.encode(prompt) | |
| prompt_size = len(prompt_tokens) | |
| max_gen_len = 100 | |
| total_len = min(getattr(model.config, 'max_position_embeddings', 2048), max_gen_len + prompt_size) | |
| # Initialize generation | |
| tokens = torch.full((1, total_len), model.config.pad_token_id).to(model.device).long() | |
| tokens[0, :prompt_size] = torch.tensor(prompt_tokens).long() | |
| input_text_mask = tokens != model.config.pad_token_id | |
| # Generate token by token | |
| prev_pos = 0 | |
| outputs = None # Initialize outputs to None | |
| for cur_pos in range(prompt_size, total_len): | |
| # Get model outputs | |
| outputs = model.forward( | |
| tokens[:, prev_pos:cur_pos], | |
| use_cache=True, | |
| past_key_values=outputs.past_key_values if prev_pos > 0 else None | |
| ) | |
| # Sample next token using the generator's sampling method | |
| ngram_tokens = tokens[0, cur_pos-generator.ngram:cur_pos].tolist() | |
| aux = { | |
| 'ngram_tokens': ngram_tokens, | |
| 'cur_pos': cur_pos, | |
| } | |
| next_token = generator.sample_next( | |
| outputs.logits[:, -1, :], | |
| aux, | |
| temperature=temperature, | |
| top_p=0.9 | |
| ) | |
| # Check for EOS token | |
| if next_token == model.config.eos_token_id: | |
| break | |
| # Decode and check if it's a special token | |
| new_text = tokenizer.decode([next_token]) | |
| if new_text not in special_tokens and not any(st in new_text for st in special_tokens): | |
| yield f"data: {json.dumps({'token': new_text, 'done': False})}\n\n" | |
| # Update token and position | |
| tokens[0, cur_pos] = next_token | |
| prev_pos = cur_pos | |
| # Send final complete text, filtering out special tokens | |
| final_tokens = tokens[0, prompt_size:cur_pos+1].tolist() | |
| final_text = tokenizer.decode(final_tokens) | |
| for st in special_tokens: | |
| final_text = final_text.replace(st, '') | |
| yield f"data: {json.dumps({'text': final_text, 'done': True})}\n\n" | |
| except Exception as e: | |
| app.logger.error(f'Error generating text: {str(e)}') | |
| yield f"data: {json.dumps({'error': str(e)})}\n\n" | |
| return Response(stream_with_context(generate_stream()), mimetype='text/event-stream') | |
| except Exception as e: | |
| app.logger.error(f'Server error: {str(e)}') | |
| return jsonify({'error': f'Server error: {str(e)}'}), 500 | |
| return app | |
| app = create_app() | |
| if __name__ == "__main__": | |
| app.run(host='0.0.0.0', port=7860) |