<!--
rope / index.html
ucalyptus's picture
Update index.html
2fbba50 verified
-->
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>RoPE Visualization</title>
  <meta name="description" content="Interactive visualization of Rotary Positional Embedding (RoPE): per-pair rotation angles, frequency spectrum, and relative-position dot products.">
  <!-- Tailwind Play CDN is left render-blocking (not deferred). -->
  <script src="https://cdn.tailwindcss.com"></script>
  <!-- Chart.js and math.js are only used inside this page's DOMContentLoaded
       handler; deferred scripts run before DOMContentLoaded fires, so they can
       load without blocking parsing. Chart.js is pinned to major version 4
       (the page's chart config uses the v3+/v4 options API) so a future major
       release cannot silently break the page. -->
  <script defer src="https://cdn.jsdelivr.net/npm/chart.js@4"></script>
  <script defer src="https://cdn.jsdelivr.net/npm/mathjs@11.6.0/lib/browser/math.js"></script>
  <style>
    .rope-vector {
      transition: all 0.3s ease;
    }
    .vector-container {
      perspective: 1000px;
    }
    .gradient-bg {
      background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
    }
    .control-panel {
      box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
    }
  </style>
</head>
<body class="gradient-bg min-h-screen p-6">
<div class="max-w-6xl mx-auto">
<!-- Page header: semantic <header> so assistive tech exposes it as the banner landmark. -->
<header class="text-center mb-8">
  <h1 class="text-4xl font-bold text-gray-800 mb-2">Rotary Positional Embedding (RoPE) Visualization</h1>
  <p class="text-lg text-gray-600">Interactive exploration of how RoPE encodes position information in transformer models</p>
</header>
<div class="grid grid-cols-1 lg:grid-cols-3 gap-6">
<!-- Control Panel: element ids are read by the page script and must not change.
     Each <label for> points at its control so clicking the label focuses it and
     screen readers announce the label. The relative offset is written Δ to keep
     it distinct from the key vector k. -->
<div class="bg-white rounded-xl p-6 control-panel">
  <h2 class="text-xl font-semibold mb-4 text-gray-800">Configuration</h2>
  <div class="space-y-4">
    <div>
      <label for="dimension" class="block text-sm font-medium text-gray-700 mb-1">Model Dimension (d)</label>
      <!-- focus:ring-2 gives the ring a width; without it focus:ring-indigo-500 draws nothing
           and focus:outline-none leaves only a subtle border-colour change. -->
      <select id="dimension" class="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:border-indigo-500">
        <option value="4">4 (Simplified)</option>
        <option value="8">8</option>
        <option value="16">16</option>
        <option value="32">32</option>
        <option value="64" selected>64</option>
        <option value="128">128</option>
      </select>
    </div>
    <div>
      <label for="baseFreq" class="block text-sm font-medium text-gray-700 mb-1">Base Frequency (θ)</label>
      <input type="range" id="baseFreq" min="1000" max="50000" step="1000" value="10000" class="w-full">
      <div class="flex justify-between text-xs text-gray-500">
        <span>1,000</span>
        <span id="baseFreqValue">10,000</span>
        <span>50,000</span>
      </div>
    </div>
    <div>
      <label for="position" class="block text-sm font-medium text-gray-700 mb-1">Position (p)</label>
      <input type="range" id="position" min="0" max="1024" step="1" value="0" class="w-full">
      <div class="flex justify-between text-xs text-gray-500">
        <span>0</span>
        <span id="positionValue">0</span>
        <span>1024</span>
      </div>
    </div>
    <div>
      <label for="relativePos" class="block text-sm font-medium text-gray-700 mb-1">Relative Position (Δ)</label>
      <input type="range" id="relativePos" min="-32" max="32" step="1" value="1" class="w-full">
      <div class="flex justify-between text-xs text-gray-500">
        <span>-32</span>
        <span id="relativePosValue">1</span>
        <span>32</span>
      </div>
    </div>
    <div class="pt-2">
      <button type="button" id="animateBtn" class="w-full bg-indigo-600 text-white py-2 px-4 rounded-md hover:bg-indigo-700 transition-colors">
        Animate Rotation
      </button>
    </div>
  </div>
  <div class="mt-6 pt-4 border-t border-gray-200">
    <h3 class="text-lg font-medium text-gray-800 mb-2">Dot Product Analysis</h3>
    <div class="space-y-2">
      <div class="flex justify-between">
        <span class="text-sm text-gray-600">q · k (original):</span>
        <span id="originalDot" class="font-mono">0.00</span>
      </div>
      <div class="flex justify-between">
        <span class="text-sm text-gray-600">R(q,p) · R(k,p+Δ):</span>
        <span id="rotatedDot" class="font-mono">0.00</span>
      </div>
      <div class="flex justify-between">
        <span class="text-sm text-gray-600">Difference:</span>
        <span id="dotDifference" class="font-mono">0.00</span>
      </div>
    </div>
  </div>
</div>
<!-- 3D Vector Visualization -->
<div class="bg-white rounded-xl p-6 vector-container">
  <h2 class="text-xl font-semibold mb-4 text-gray-800">3D Vector Rotation</h2>
  <div class="relative h-64 w-full mb-4">
    <!-- w-full/h-full are required: <canvas> is a replaced element, so inset-0 alone
         does not stretch it and it would keep its default 300x150 box (the script
         sizes the drawing buffer from offsetWidth/offsetHeight). -->
    <canvas id="vectorCanvas" class="absolute inset-0 w-full h-full" role="img" aria-label="Arrows showing vector dimensions 0 and 1 before (green) and after (blue) RoPE rotation"></canvas>
  </div>
  <div class="text-sm text-gray-600">
    <p>Dimensions 0 and 1 of a sample vector, drawn before (green) and after (blue) RoPE rotation at position p. Dimension 2 is not rotated here; it only adds a slight perspective scaling. Every pair of dimensions (2i, 2i+1) is rotated the same way, by an angle determined by the position and that pair's frequency.</p>
  </div>
</div>
<!-- Frequency Analysis: Chart.js draws into this canvas; role/aria-label give the
     otherwise opaque canvas an accessible name. -->
<div class="bg-white rounded-xl p-6">
  <h2 class="text-xl font-semibold mb-4 text-gray-800">Frequency Spectrum</h2>
  <div class="relative h-64 w-full mb-4">
    <canvas id="freqChart" role="img" aria-label="Line chart of the RoPE frequency for each dimension pair"></canvas>
  </div>
  <div class="text-sm text-gray-600">
    <p>Shows the geometric progression of frequencies across dimensions. Lower dimensions have higher frequencies (shorter wavelengths) while higher dimensions have lower frequencies.</p>
  </div>
</div>
</div>
<!-- Detailed Vector View: rows of #vectorDetails are generated by the page script. -->
<div class="mt-6 bg-white rounded-xl p-6">
  <h2 class="text-xl font-semibold mb-4 text-gray-800">Vector Component Rotation</h2>
  <div class="overflow-x-auto">
    <table class="min-w-full divide-y divide-gray-200">
      <caption class="sr-only">Original and RoPE-rotated values, rotation angle and frequency for the first dimension pairs of a sample vector</caption>
      <thead class="bg-gray-50">
        <tr>
          <th scope="col" class="px-6 py-3 text-left text-xs font-medium text-gray-500 uppercase tracking-wider">Dimension Pair</th>
          <th scope="col" class="px-6 py-3 text-left text-xs font-medium text-gray-500 uppercase tracking-wider">Original Values</th>
          <th scope="col" class="px-6 py-3 text-left text-xs font-medium text-gray-500 uppercase tracking-wider">Rotated Values</th>
          <th scope="col" class="px-6 py-3 text-left text-xs font-medium text-gray-500 uppercase tracking-wider">Rotation Angle</th>
          <th scope="col" class="px-6 py-3 text-left text-xs font-medium text-gray-500 uppercase tracking-wider">Frequency</th>
        </tr>
      </thead>
      <tbody id="vectorDetails" class="bg-white divide-y divide-gray-200">
        <!-- Filled by JavaScript -->
      </tbody>
    </table>
  </div>
</div>
<!-- Explanation Section -->
<div class="mt-6 bg-white rounded-xl p-6">
<h2 class="text-xl font-semibold mb-4 text-gray-800">How RoPE Works</h2>
<div class="prose max-w-none text-gray-700">
<p>Rotary Positional Embedding (RoPE) encodes position information by rotating pairs of vector components:</p>
<ul class="list-disc pl-5 space-y-1">
<li>Each pair of dimensions (2i, 2i+1) is treated as a 2D vector</li>
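<li>The vector is rotated by the angle p·ωᵢ, where p is the token position and ωᵢ is the frequency of pair i</li>
<li>Frequencies decrease geometrically with the pair index: ωᵢ = 1/θ^(2i/d), where θ is the base frequency and d is the model dimension</li>
<li>This makes the attention dot product encode relative position: for a query rotated at position m and a key rotated at position n, their dot product depends on the positions only through the offset m − n</li>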
<li>The norm (length) of vectors remains unchanged, preserving semantic information</li>
</ul>
<p class="mt-4">The rotation matrix for each pair is:</p>
<pre class="bg-gray-100 p-4 rounded-md overflow-x-auto text-sm"><code>Mᵢ = [ cos(pωᵢ)  -sin(pωᵢ) ]
     [ sin(pωᵢ)   cos(pωᵢ) ]</code></pre>
</div>
</div>
</div>
<script>
// Initialize components
document.addEventListener('DOMContentLoaded', function() {
// Get DOM elements
const dimensionSelect = document.getElementById('dimension');
const baseFreqSlider = document.getElementById('baseFreq');
const baseFreqValue = document.getElementById('baseFreqValue');
const positionSlider = document.getElementById('position');
const positionValue = document.getElementById('positionValue');
const relativePosSlider = document.getElementById('relativePos');
const relativePosValue = document.getElementById('relativePosValue');
const animateBtn = document.getElementById('animateBtn');
const originalDot = document.getElementById('originalDot');
const rotatedDot = document.getElementById('rotatedDot');
const dotDifference = document.getElementById('dotDifference');
const vectorDetails = document.getElementById('vectorDetails');

// 2D canvas for the vector view. Its drawing buffer is sized to the element's
// laid-out size so drawing coordinates correspond to CSS pixels.
const vectorCanvas = document.getElementById('vectorCanvas');
const vectorCtx = vectorCanvas.getContext('2d');

// Sync the canvas buffer size with its current layout size.
function resizeVectorCanvas() {
    vectorCanvas.width = vectorCanvas.offsetWidth;
    vectorCanvas.height = vectorCanvas.offsetHeight;
}
resizeVectorCanvas();

// Frequency chart (data is filled in by renderFrequencyChart).
const freqChartCanvas = document.getElementById('freqChart');
const freqChart = new Chart(freqChartCanvas, {
    type: 'line',
    data: { labels: [], datasets: [{ data: [], borderColor: '#4f46e5', tension: 0.1 }] },
    options: {
        responsive: true,
        maintainAspectRatio: false,
        plugins: { legend: { display: false } },
        scales: {
            // ωᵢ = 1/θ^(2i/d) runs from 1 down to roughly 1/θ, i.e. several
            // orders of magnitude. On a log axis that geometric progression is a
            // straight line; on a linear axis all but the first few pairs sit on 0.
            y: { type: 'logarithmic', title: { display: true, text: 'Frequency (ωᵢ)' } },
            x: { title: { display: true, text: 'Dimension Pair (i)' } }
        }
    }
});

// State. Seeded from the controls' current values instead of literals so the
// model agrees with what is on screen even when the browser restores form
// values (e.g. after a reload).
const state = {
    dimension: parseInt(dimensionSelect.value, 10),
    baseFreq: parseInt(baseFreqSlider.value, 10),
    position: parseInt(positionSlider.value, 10),
    relativePos: parseInt(relativePosSlider.value, 10),
    isAnimating: false,
    animationFrame: null
};
// The readouts are static text in the markup; mirror the actual values into them.
baseFreqValue.textContent = state.baseFreq.toLocaleString();
positionValue.textContent = state.position;
relativePosValue.textContent = state.relativePos;

// Event listeners
dimensionSelect.addEventListener('change', updateDimension);
baseFreqSlider.addEventListener('input', updateBaseFreq);
positionSlider.addEventListener('input', updatePosition);
relativePosSlider.addEventListener('input', updateRelativePos);
animateBtn.addEventListener('click', toggleAnimation);

// Initial render
updateAll();

// On resize only the hand-drawn canvas needs work; Chart.js (responsive: true)
// resizes its own canvas.
window.addEventListener('resize', function() {
    resizeVectorCanvas();
    renderVectorVisualization();
});
// Update functions
function updateDimension() {
    // <select> changed: adopt the new model dimension d, then redraw every view.
    const chosenDimension = parseInt(dimensionSelect.value);
    state.dimension = chosenDimension;
    updateAll();
}
function updateBaseFreq() {
    // θ slider moved: store θ, show it with locale digit grouping, then redraw.
    const theta = parseInt(baseFreqSlider.value);
    state.baseFreq = theta;
    baseFreqValue.textContent = theta.toLocaleString();
    updateAll();
}
function updatePosition() {
    // Position slider moved: copy p into state and its readout, then redraw.
    const p = parseInt(positionSlider.value);
    state.position = p;
    positionValue.textContent = p;
    updateAll();
}
function updateRelativePos() {
    // Offset slider moved: copy the key's relative offset into state and its readout, then redraw.
    const offset = parseInt(relativePosSlider.value);
    state.relativePos = offset;
    relativePosValue.textContent = offset;
    updateAll();
}
function toggleAnimation() {
    // Flip the running flag and relabel the button, then either start the
    // frame loop or cancel the pending frame.
    const running = !state.isAnimating;
    state.isAnimating = running;
    animateBtn.textContent = running ? 'Stop Animation' : 'Animate Rotation';
    if (!running) {
        cancelAnimationFrame(state.animationFrame);
        return;
    }
    animate();
}
/**
 * One animation step.
 *
 * Advances p by one, wrapping to 0 after the position slider's max. It then
 * syncs the slider and its readout, redraws, and schedules the next frame.
 * Does nothing once state.isAnimating is false.
 */
function animate() {
    if (!state.isAnimating) return;
    // Wrap at the slider's own (inclusive) max rather than a hard-coded 1024,
    // which never reached the slider's last value; fall back to 1024 if the
    // attribute is missing or malformed.
    const maxPosition = parseInt(positionSlider.max, 10);
    const wrap = Number.isFinite(maxPosition) ? maxPosition + 1 : 1024;
    state.position = (state.position + 1) % wrap;
    positionSlider.value = state.position;
    positionValue.textContent = state.position;
    updateAll();
    state.animationFrame = requestAnimationFrame(animate);
}
function updateAll() {
    // Redraw every state-dependent view, always in this order.
    const renderers = [
        renderVectorVisualization,
        renderFrequencyChart,
        renderVectorDetails,
        calculateDotProducts
    ];
    renderers.forEach(render => render());
}
// Render functions
/**
 * Draw dimensions 0 and 1 of a sample vector on the canvas.
 *
 * The original pair is drawn in green and the pair rotated by p·ω₀ (pair 0's
 * RoPE angle at the current position) in blue. Dimension 2 is not rotated; it
 * only scales the projected arrows as a simple perspective cue.
 */
function renderVectorVisualization() {
    const ctx = vectorCtx;
    const width = vectorCanvas.width;
    const height = vectorCanvas.height;
    const centerX = width / 2;
    const centerY = height / 2;
    const scale = Math.min(width, height) * 0.3;

    ctx.clearRect(0, 0, width, height);

    // Axes through the centre of the canvas.
    ctx.strokeStyle = '#e5e7eb';
    ctx.lineWidth = 1;
    ctx.beginPath();
    ctx.moveTo(0, centerY);
    ctx.lineTo(width, centerY);
    ctx.moveTo(centerX, 0);
    ctx.lineTo(centerX, height);
    ctx.stroke();

    // Reuse one sample vector (kept on state) until d changes. Drawing a
    // freshly randomized vector on every call made the arrows jump on each
    // slider tick / animation frame, hiding the rotation this view shows.
    if (!state.sampleVector || state.sampleVector.length !== state.dimension) {
        state.sampleVector = Array.from({length: state.dimension}, () => math.random(-1, 1));
    }
    const vector = state.sampleVector;
    const frequencies = calculateFrequencies();

    const dim1 = 0;
    const dim2 = 1;
    const dim3 = 2;
    const x1 = vector[dim1];
    const y1 = vector[dim2];
    const z1 = vector[dim3];

    // Rotate the (dim1, dim2) pair by p·ωᵢ, where i = dim1 / 2.
    const angle = state.position * frequencies[Math.floor(dim1 / 2)];
    const cosA = Math.cos(angle);
    const sinA = Math.sin(angle);
    const rotX1 = x1 * cosA - y1 * sinA;
    const rotY1 = x1 * sinA + y1 * cosA;
    const rotZ1 = z1; // dim3 belongs to the next pair; this view leaves it unrotated

    // Simple fake perspective: z only scales the arrow length.
    const project = (x, y, z) => {
        const perspective = 1 + z * 0.2;
        return {
            x: centerX + x * scale * perspective,
            y: centerY - y * scale * perspective // canvas y grows downward
        };
    };

    // Draw a line from the canvas centre to a projected point.
    const drawArrow = (to, color) => {
        ctx.strokeStyle = color;
        ctx.lineWidth = 2;
        ctx.beginPath();
        ctx.moveTo(centerX, centerY);
        ctx.lineTo(to.x, to.y);
        ctx.stroke();
    };

    const origProj = project(x1, y1, z1);
    const rotProj = project(rotX1, rotY1, rotZ1);
    drawArrow(origProj, '#10b981');
    drawArrow(rotProj, '#3b82f6');

    // Labels
    ctx.fillStyle = '#111827';
    ctx.font = '12px sans-serif';
    ctx.fillText('Original', origProj.x + 5, origProj.y - 5);
    ctx.fillText('Rotated', rotProj.x + 5, rotProj.y + 15);

    // Which dimensions are shown
    ctx.fillStyle = '#6b7280';
    ctx.font = '10px sans-serif';
    ctx.fillText(`Dimensions ${dim1},${dim2}`, 10, 20);
}
/**
 * Plot ωᵢ against the pair index i.
 *
 * Labels are 0-based, matching ωᵢ = 1/θ^(2i/d), where pair 0 has ω₀ = 1.
 * The frequencies depend only on d and θ, so the chart is pushed to Chart.js
 * only when one of those changes. This avoids re-running a chart update (and
 * its animation) on every position change and every animation frame.
 */
function renderFrequencyChart() {
    const key = state.dimension + ':' + state.baseFreq;
    if (state.frequencyChartKey === key) return;
    state.frequencyChartKey = key;

    const frequencies = calculateFrequencies();
    freqChart.data.labels = frequencies.map((_, i) => i);
    freqChart.data.datasets[0].data = frequencies;
    freqChart.update();
}
/**
 * Rebuild the details table for the first (up to 5) dimension pairs of the
 * sample vector.
 *
 * Each row shows the original values, the values rotated by p·ωᵢ, the angle,
 * and ωᵢ. A trailing row reports how many pairs were left out.
 */
function renderVectorDetails() {
    vectorDetails.innerHTML = '';

    const frequencies = calculateFrequencies();

    // Keep the sample vector on state so it stays fixed across renders and is
    // only regenerated when d changes; a fresh random vector on every call made
    // the table flicker on each slider tick / animation frame.
    if (!state.sampleVector || state.sampleVector.length !== state.dimension) {
        state.sampleVector = Array.from({length: state.dimension}, () => math.random(-1, 1));
    }
    const vector = state.sampleVector;

    const totalPairs = Math.floor(state.dimension / 2);
    const pairsToShow = Math.min(5, totalPairs); // limited for space

    for (let i = 0; i < pairsToShow; i++) {
        const dim1 = 2 * i;
        const dim2 = 2 * i + 1;
        const val1 = vector[dim1];
        const val2 = vector[dim2];

        const angle = state.position * frequencies[i];
        const cosA = Math.cos(angle);
        const sinA = Math.sin(angle);
        const rotVal1 = val1 * cosA - val2 * sinA;
        const rotVal2 = val1 * sinA + val2 * cosA;

        // Only numbers are interpolated, so innerHTML is safe here.
        const row = document.createElement('tr');
        row.className = 'rope-vector';
        row.innerHTML = `
<td class="px-6 py-4 whitespace-nowrap text-sm font-mono text-gray-900">${dim1}, ${dim2}</td>
<td class="px-6 py-4 whitespace-nowrap text-sm font-mono text-gray-500">${val1.toFixed(3)}, ${val2.toFixed(3)}</td>
<td class="px-6 py-4 whitespace-nowrap text-sm font-mono text-blue-600">${rotVal1.toFixed(3)}, ${rotVal2.toFixed(3)}</td>
<td class="px-6 py-4 whitespace-nowrap text-sm font-mono text-gray-900">${angle.toFixed(5)} rad</td>
<td class="px-6 py-4 whitespace-nowrap text-sm font-mono text-gray-900">${frequencies[i].toExponential(3)}</td>
`;
        vectorDetails.appendChild(row);
    }

    if (pairsToShow < totalPairs) {
        const row = document.createElement('tr');
        row.innerHTML = `
<td colspan="5" class="px-6 py-2 text-center text-sm text-gray-500">... ${totalPairs - pairsToShow} more pairs ...</td>
`;
        vectorDetails.appendChild(row);
    }
}
/**
 * Compare q·k with R(q,p)·R(k,p+Δ) for sample vectors q and k, where
 * Δ = state.relativePos, and write both values and their absolute difference
 * to the page.
 *
 * q and k are created once per dimension and kept on state. With fixed
 * vectors, the rotated product depends only on Δ and not on p (up to
 * rounding), which is the relative-position property RoPE is built for.
 * Regenerating them on every call turned these numbers into noise.
 */
function calculateDotProducts() {
    if (!state.dotVectors || state.dotVectors.q.length !== state.dimension) {
        state.dotVectors = {
            q: Array.from({length: state.dimension}, () => math.random(-1, 1)),
            k: Array.from({length: state.dimension}, () => math.random(-1, 1))
        };
    }
    const { q, k } = state.dotVectors;

    const frequencies = calculateFrequencies();

    const originalDotValue = dotProduct(q, k);

    // Query at position p, key at position p + Δ.
    const rotatedQ = applyRoPE(q, state.position, frequencies);
    const rotatedK = applyRoPE(k, state.position + state.relativePos, frequencies);
    const rotatedDotValue = dotProduct(rotatedQ, rotatedK);

    originalDot.textContent = originalDotValue.toFixed(4);
    rotatedDot.textContent = rotatedDotValue.toFixed(4);
    dotDifference.textContent = Math.abs(originalDotValue - rotatedDotValue).toFixed(4);
}
// Helper functions
/**
 * RoPE frequency for each dimension pair: ωᵢ = 1 / base^(2i/d), for
 * i = 0 … ⌊d/2⌋−1. The first value is always ω₀ = 1.
 *
 * Both arguments are optional and default to the current UI state, so existing
 * zero-argument callers are unchanged.
 *
 * @param {number} [dimension=state.dimension] model dimension d
 * @param {number} [base=state.baseFreq] base frequency θ
 * @returns {number[]} one frequency per dimension pair
 */
function calculateFrequencies(dimension = state.dimension, base = state.baseFreq) {
    const numPairs = Math.floor(dimension / 2);
    return Array.from({ length: numPairs }, (_, i) => 1 / Math.pow(base, (2 * i) / dimension));
}
/**
 * Return a rotated copy of `vector`.
 *
 * Each pair (2i, 2i+1) is rotated by the angle position·frequencies[i]. The
 * input array is not modified. An odd trailing component has no partner and
 * is copied through unchanged.
 */
function applyRoPE(vector, position, frequencies) {
    const out = vector.slice();
    const pairCount = Math.floor(vector.length / 2);
    for (let pair = 0; pair < pairCount; pair++) {
        const theta = position * frequencies[pair];
        const c = Math.cos(theta);
        const s = Math.sin(theta);
        const a = vector[2 * pair];
        const b = vector[2 * pair + 1];
        out[2 * pair] = a * c - b * s;
        out[2 * pair + 1] = a * s + b * c;
    }
    return out;
}
/**
 * Σ a[i]·b[i], summed left to right over the indices of `a`.
 * If `b` is shorter than `a`, the missing entries make the result NaN.
 */
function dotProduct(a, b) {
    let total = 0;
    for (let i = 0; i < a.length; i++) {
        total += a[i] * b[i];
    }
    return total;
}
});
</script>
</body>
</html>