""" Quick test script for B2NL v6.1.2 app functionality """ import sys from pathlib import Path import torch # Add path parent_dir = Path(__file__).parent.parent.parent sys.path.insert(0, str(parent_dir / 'intelligent-tokenizer_v6.1.2')) from core.unified_model import IntelligentTokenizerModelV61 from core.byte_tokenizer_v6 import ByteTokenizerV6 def test_model(): device = torch.device('cpu') tokenizer = ByteTokenizerV6(max_seq_len=64) model = IntelligentTokenizerModelV61(vocab_size=260, max_seq_len=64).to(device) # Load checkpoint checkpoint_path = parent_dir / 'intelligent-tokenizer_v6.1.2' / 'checkpoints' / 'v612_compression_first' / 'best_model.pt' if checkpoint_path.exists(): print(f"Loading checkpoint from {checkpoint_path}") checkpoint = torch.load(str(checkpoint_path), map_location=device) model.load_state_dict(checkpoint['model_state_dict']) print(f"[OK] Loaded checkpoint: Epoch {checkpoint.get('epoch', 'N/A')}") model.eval() # Test Korean text test_text = "안녕하세요. 오늘 날씨가 좋네요." print(f"\nTest text: {test_text}") # Encode byte_seq = list(test_text.encode('utf-8'))[:62] print(f"Bytes: {len(byte_seq)}") # Prepare input input_ids = torch.tensor([[tokenizer.BOS] + byte_seq + [tokenizer.EOS]], dtype=torch.long).to(device) if input_ids.size(1) < 64: padding = torch.full((1, 64 - input_ids.size(1)), tokenizer.PAD, dtype=torch.long).to(device) input_ids = torch.cat([input_ids, padding], dim=1) attention_mask = (input_ids != tokenizer.PAD).float() # Forward pass - v6.1.2 production mode with torch.no_grad(): outputs = model( input_ids=input_ids, attention_mask=attention_mask, labels=input_ids, epoch=233, # Match checkpoint epoch for best performance use_cross_attention=True # Enable cross-attention for better reconstruction ) print(f"\n[OK] Model outputs available: {list(outputs.keys())}") # Check boundaries for groups if 'eojeol_boundaries' in outputs: boundaries = torch.argmax(outputs['eojeol_boundaries'], dim=-1)[0] num_groups = torch.sum(boundaries == 1).item() + 1 compression = len(byte_seq) / num_groups print(f"[OK] Compression: {len(byte_seq)} bytes -> {num_groups} tokens = {compression:.1f}:1") # Visualize groups groups = [] current_group = [] boundaries_np = boundaries.cpu().numpy() for i in range(min(len(byte_seq), len(boundaries_np))): is_boundary = (i == 0) or (boundaries_np[i] == 1) if is_boundary and current_group: try: group_text = bytes(current_group).decode('utf-8', errors='replace') groups.append(f"<{group_text}>") except: groups.append(f"<{len(current_group)}B>") current_group = [] if i < len(byte_seq): current_group.append(byte_seq[i]) if current_group: try: group_text = bytes(current_group).decode('utf-8', errors='replace') groups.append(f"<{group_text}>") except: groups.append(f"<{len(current_group)}B>") print(f"[OK] Groups: {' '.join(groups)}") # Check embeddings if 'encoder_hidden_states' in outputs: # encoder_hidden_states is a tuple of all layer outputs last_hidden = outputs['encoder_hidden_states'][-1] if isinstance(outputs['encoder_hidden_states'], tuple) else outputs['encoder_hidden_states'] embeddings = last_hidden[0, 0, :20] # First token, first 20 dims emb_values = embeddings.cpu().numpy() print(f"\n[OK] Embeddings (first 20 dims):") for i in range(0, len(emb_values), 5): dims = emb_values[i:min(i+5, len(emb_values))] dim_strs = [f'{v:7.4f}' for v in dims] print(f" Dim {i:2d}-{min(i+4, len(emb_values)-1):2d}: [{', '.join(dim_strs)}]") print(f"\n Stats - Mean: {emb_values.mean():.4f}, Std: {emb_values.std():.4f}, Min: {emb_values.min():.4f}, Max: {emb_values.max():.4f}") # Check reconstruction if 'logits' in outputs: pred_ids = outputs['logits'].argmax(dim=-1)[0] # Find valid length valid_length = 64 for i in range(1, len(pred_ids)): if pred_ids[i] == 256 or pred_ids[i] == 258: valid_length = i break pred_ids = pred_ids[1:valid_length] pred_ids = pred_ids[pred_ids < 256] if len(pred_ids) > 0: try: reconstructed = bytes(pred_ids.cpu().numpy()).decode('utf-8', errors='ignore') print(f"\n[OK] Reconstructed: {reconstructed}") # Calculate accuracy orig_text = test_text[:len(reconstructed)] matches = sum(1 for o, r in zip(orig_text, reconstructed) if o == r) accuracy = (matches / len(orig_text)) * 100 print(f"[OK] Accuracy: {accuracy:.1f}%") except: print("[ERROR] Reconstruction decode error") print("\n[SUCCESS] All tests passed!") else: print(f"[ERROR] Checkpoint not found at {checkpoint_path}") return False return True if __name__ == "__main__": print("="*60) print("B2NL v6.1.2 App Test") print("="*60) success = test_model() if success: print("\n[READY] Ready to run the Gradio app!") print("Run: python app.py") else: print("\n[WARNING] Please check the checkpoint path")