|
|
<!DOCTYPE html> |
|
|
<html lang="en"> |
|
|
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>RoPE Visualization</title>
  <meta name="description" content="Interactive visualization of Rotary Positional Embedding (RoPE): per-pair rotation angles, the frequency spectrum and the relative-position property of rotated dot products.">

  <!-- Tailwind Play CDN: fine for a self-contained demo page; Tailwind recommends
       a build step for production use. -->
  <script src="https://cdn.tailwindcss.com"></script>
  <!-- Chart.js is pinned (as math.js already is) so a new major release on the CDN
       cannot silently break the chart configuration used by the page script. -->
  <script src="https://cdn.jsdelivr.net/npm/chart.js@4.4.1/dist/chart.umd.js"></script>
  <script src="https://cdn.jsdelivr.net/npm/mathjs@11.6.0/lib/browser/math.js"></script>

  <style>
    /* Applied to the component-table rows the page script creates. */
    .rope-vector {
      transition: all 0.3s ease;
    }

    .vector-container {
      perspective: 1000px;
    }

    .gradient-bg {
      background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
    }

    .control-panel {
      box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);
    }
  </style>
</head>
|
|
<body class="gradient-bg min-h-screen p-6"> |
|
|
<div class="max-w-6xl mx-auto"> |
|
|
<header class="text-center mb-8">
  <h1 class="text-4xl font-bold text-gray-800 mb-2">Rotary Positional Embedding (RoPE) Visualization</h1>
  <p class="text-lg text-gray-600">Interactive exploration of how RoPE encodes position information in transformer models</p>
</header>
|
|
|
|
|
<div class="grid grid-cols-1 lg:grid-cols-3 gap-6"> |
|
|
|
|
|
<section class="bg-white rounded-xl p-6 control-panel" aria-labelledby="config-heading">
  <h2 id="config-heading" class="text-xl font-semibold mb-4 text-gray-800">Configuration</h2>

  <div class="space-y-4">
    <div>
      <!-- `for` ties each label to its control so the control has an accessible name. -->
      <label for="dimension" class="block text-sm font-medium text-gray-700 mb-1">Model Dimension (d)</label>
      <!-- focus:ring-2 gives the visible focus indicator that focus:outline-none removed. -->
      <select id="dimension" class="w-full px-3 py-2 border border-gray-300 rounded-md shadow-sm focus:outline-none focus:ring-2 focus:ring-indigo-500 focus:border-indigo-500">
        <option value="4">4 (Simplified)</option>
        <option value="8">8</option>
        <option value="16">16</option>
        <option value="32">32</option>
        <option value="64" selected>64</option>
        <option value="128">128</option>
      </select>
    </div>

    <div>
      <label for="baseFreq" class="block text-sm font-medium text-gray-700 mb-1">Base Frequency (θ)</label>
      <input id="baseFreq" class="w-full" type="range" min="1000" max="50000" step="1000" value="10000">
      <div class="flex justify-between text-xs text-gray-500">
        <span>1,000</span>
        <span id="baseFreqValue">10,000</span>
        <span>50,000</span>
      </div>
    </div>

    <div>
      <label for="position" class="block text-sm font-medium text-gray-700 mb-1">Position (p)</label>
      <input id="position" class="w-full" type="range" min="0" max="1024" step="1" value="0">
      <div class="flex justify-between text-xs text-gray-500">
        <span>0</span>
        <span id="positionValue">0</span>
        <span>1024</span>
      </div>
    </div>

    <div>
      <!-- Δ rather than k: k already names the key vector in the dot-product labels below. -->
      <label for="relativePos" class="block text-sm font-medium text-gray-700 mb-1">Relative Position (Δ)</label>
      <input id="relativePos" class="w-full" type="range" min="-32" max="32" step="1" value="1">
      <div class="flex justify-between text-xs text-gray-500">
        <span>-32</span>
        <span id="relativePosValue">1</span>
        <span>32</span>
      </div>
    </div>

    <div class="pt-2">
      <button id="animateBtn" class="w-full bg-indigo-600 text-white py-2 px-4 rounded-md hover:bg-indigo-700 transition-colors" type="button">
        Animate Rotation
      </button>
    </div>
  </div>

  <div class="mt-6 pt-4 border-t border-gray-200">
    <h3 class="text-lg font-medium text-gray-800 mb-2">Dot Product Analysis</h3>
    <div class="space-y-2">
      <div class="flex justify-between">
        <span class="text-sm text-gray-600">q · k (original):</span>
        <span id="originalDot" class="font-mono">0.00</span>
      </div>
      <div class="flex justify-between">
        <span class="text-sm text-gray-600">R(q,p) · R(k,p+Δ):</span>
        <span id="rotatedDot" class="font-mono">0.00</span>
      </div>
      <div class="flex justify-between">
        <span class="text-sm text-gray-600">Difference:</span>
        <span id="dotDifference" class="font-mono">0.00</span>
      </div>
    </div>
  </div>
</section>
|
|
|
|
|
|
|
|
<section class="bg-white rounded-xl p-6 vector-container" aria-labelledby="rotation-heading">
  <h2 id="rotation-heading" class="text-xl font-semibold mb-4 text-gray-800">Vector Rotation (Dimensions 0–1)</h2>
  <div class="relative h-64 w-full mb-4">
    <!-- w-full/h-full are required: an absolutely positioned canvas is a replaced
         element and otherwise keeps its intrinsic 300×150 size instead of filling
         the box, which the script then reads back via offsetWidth/offsetHeight. -->
    <canvas id="vectorCanvas" class="absolute inset-0 w-full h-full" role="img"
            aria-label="Original (green) and RoPE-rotated (blue) vector for dimensions 0 and 1 at the current position"></canvas>
  </div>
  <div class="text-sm text-gray-600">
    <p>Shows how RoPE rotates the first pair of vector components (dimensions 0 and 1) by the angle p·ω₀ at the current position p. The green arrow is the original pair and the blue arrow the rotated one; the third component only adds a slight perspective scale.</p>
  </div>
</section>
|
|
|
|
|
|
|
|
<section class="bg-white rounded-xl p-6" aria-labelledby="spectrum-heading">
  <h2 id="spectrum-heading" class="text-xl font-semibold mb-4 text-gray-800">Frequency Spectrum</h2>
  <div class="relative h-64 w-full mb-4">
    <!-- Chart.js draws into a bare canvas; role/aria-label give it an accessible name. -->
    <canvas id="freqChart" role="img"
            aria-label="Line chart of the rotation frequency ωᵢ for each dimension pair i; frequencies fall geometrically as i increases"></canvas>
  </div>
  <div class="text-sm text-gray-600">
    <p>Shows the geometric progression of frequencies across dimension pairs, ωᵢ = 1/θ<sup>2i/d</sup>. Lower pairs have higher frequencies (shorter wavelengths), while higher pairs have lower frequencies.</p>
  </div>
</section>
|
|
</div> |
|
|
|
|
|
|
|
|
<section class="mt-6 bg-white rounded-xl p-6" aria-labelledby="components-heading">
  <h2 id="components-heading" class="text-xl font-semibold mb-4 text-gray-800">Vector Component Rotation</h2>
  <div class="overflow-x-auto">
    <table class="min-w-full divide-y divide-gray-200">
      <caption class="sr-only">Original and rotated values, rotation angle and frequency for the first few dimension pairs of a sample vector at the current position</caption>
      <thead class="bg-gray-50">
        <tr>
          <th scope="col" class="px-6 py-3 text-left text-xs font-medium text-gray-500 uppercase tracking-wider">Dimension Pair</th>
          <th scope="col" class="px-6 py-3 text-left text-xs font-medium text-gray-500 uppercase tracking-wider">Original Values</th>
          <th scope="col" class="px-6 py-3 text-left text-xs font-medium text-gray-500 uppercase tracking-wider">Rotated Values</th>
          <th scope="col" class="px-6 py-3 text-left text-xs font-medium text-gray-500 uppercase tracking-wider">Rotation Angle (p·ωᵢ)</th>
          <th scope="col" class="px-6 py-3 text-left text-xs font-medium text-gray-500 uppercase tracking-wider">Frequency (ωᵢ)</th>
        </tr>
      </thead>
      <!-- Rows are generated by the page script. -->
      <tbody id="vectorDetails" class="bg-white divide-y divide-gray-200"></tbody>
    </table>
  </div>
</section>
|
|
|
|
|
|
|
|
<section class="mt-6 bg-white rounded-xl p-6" aria-labelledby="how-heading">
  <h2 id="how-heading" class="text-xl font-semibold mb-4 text-gray-800">How RoPE Works</h2>
  <div class="prose max-w-none text-gray-700">
    <p>Rotary Positional Embedding (RoPE) encodes position information by rotating pairs of vector components:</p>
    <ul class="list-disc pl-5 space-y-1">
      <li>Each pair of dimensions (2i, 2i+1) is treated as a 2D vector.</li>
      <li>At position p, pair i is rotated by the angle p·ωᵢ, where ωᵢ is that pair's frequency.</li>
      <li>Frequencies decrease geometrically with the pair index: ωᵢ = 1/θ<sup>2i/d</sup>, where θ is the base frequency and d the model dimension.</li>
      <li>This makes attention position-relative: the dot product of a query rotated at position m with a key rotated at position n depends only on the offset n − m, not on m and n individually.</li>
      <li>Rotations preserve each vector's norm (length), so the magnitude of the embedding is unchanged.</li>
    </ul>
    <p class="mt-4">The rotation matrix for pair i at position p is:</p>
    <div class="bg-gray-100 p-4 rounded-md overflow-x-auto">
      <!-- <pre> keeps the column alignment that collapsed whitespace destroyed. -->
      <pre class="text-sm"><code>Mᵢ = [ cos(pωᵢ)  −sin(pωᵢ) ]
     [ sin(pωᵢ)   cos(pωᵢ) ]</code></pre>
    </div>
  </div>
</section>
|
|
</div> |
|
|
|
|
|
<script> |
|
|
|
|
|
document.addEventListener('DOMContentLoaded', function() { |
|
|
|
|
|
// Resolve every control and readout element the script touches, once, up front.
const byId = (id) => document.getElementById(id);

// Configuration controls, their value readouts, and the animation toggle.
const dimensionSelect = byId('dimension');
const baseFreqSlider = byId('baseFreq');
const baseFreqValue = byId('baseFreqValue');
const positionSlider = byId('position');
const positionValue = byId('positionValue');
const relativePosSlider = byId('relativePos');
const relativePosValue = byId('relativePosValue');
const animateBtn = byId('animateBtn');

// Dot-product readouts and the body of the component table.
const originalDot = byId('originalDot');
const rotatedDot = byId('rotatedDot');
const dotDifference = byId('dotDifference');
const vectorDetails = byId('vectorDetails');
|
|
|
|
|
|
|
|
// Rotation diagram: match the canvas drawing buffer to its laid-out size
// (width is assigned before height is read, as before).
const vectorCanvas = document.getElementById('vectorCanvas');
const vectorCtx = vectorCanvas.getContext('2d');
vectorCanvas.width = vectorCanvas.offsetWidth;
vectorCanvas.height = vectorCanvas.offsetHeight;

// Frequency spectrum: an empty line chart whose labels/data the render code fills in.
const freqChartCanvas = document.getElementById('freqChart');
const freqChartAxes = {
  y: { title: { display: true, text: 'Frequency (ωᵢ)' } },
  x: { title: { display: true, text: 'Dimension Pair (i)' } }
};
const freqChart = new Chart(freqChartCanvas, {
  type: 'line',
  data: {
    labels: [],
    datasets: [{ data: [], borderColor: '#4f46e5', tension: 0.1 }]
  },
  options: {
    responsive: true,
    maintainAspectRatio: false,
    plugins: { legend: { display: false } },
    scales: freqChartAxes
  }
});
|
|
|
|
|
|
|
|
// Current visualization parameters plus animation bookkeeping.
// Seeded from the controls rather than hard-coded defaults: browsers (notably
// Firefox) can restore form-control values on reload, and the first render must
// agree with what the controls show. With untouched controls the values equal
// the markup defaults (64, 10000, 0, 1).
let state = {
  dimension: parseInt(dimensionSelect.value, 10),
  baseFreq: parseInt(baseFreqSlider.value, 10),
  position: parseInt(positionSlider.value, 10),
  relativePos: parseInt(relativePosSlider.value, 10),
  isAnimating: false,
  animationFrame: null
};

// Keep the numeric readouts beside the sliders consistent with the seeded state,
// formatted the same way the input handlers format them.
baseFreqValue.textContent = state.baseFreq.toLocaleString();
positionValue.textContent = state.position;
relativePosValue.textContent = state.relativePos;
|
|
|
|
|
|
|
|
// Wire each control to its handler (the handlers are hoisted function
// declarations), registering in the same order as listed.
const controlBindings = [
  [dimensionSelect, 'change', updateDimension],
  [baseFreqSlider, 'input', updateBaseFreq],
  [positionSlider, 'input', updatePosition],
  [relativePosSlider, 'input', updateRelativePos],
  [animateBtn, 'click', toggleAnimation]
];
controlBindings.forEach(([target, eventName, handler]) => {
  target.addEventListener(eventName, handler);
});

// First render from the initial state.
updateAll();
|
|
|
|
|
|
|
|
// On window resize, re-fit the rotation canvas's drawing buffer to its new
// layout size and redraw it (reassigning width/height also clears the canvas).
function handleWindowResize() {
  vectorCanvas.width = vectorCanvas.offsetWidth;
  vectorCanvas.height = vectorCanvas.offsetHeight;
  renderVectorVisualization();
}

window.addEventListener('resize', handleWindowResize);
|
|
|
|
|
|
|
|
/** Store the newly selected model dimension and redraw every view. */
function updateDimension() {
  const selected = dimensionSelect.value;
  state.dimension = parseInt(selected, 10);
  updateAll();
}
|
|
|
|
|
/** Read the base-frequency slider, refresh its formatted readout, and redraw. */
function updateBaseFreq() {
  const base = parseInt(baseFreqSlider.value, 10);
  state.baseFreq = base;
  baseFreqValue.textContent = base.toLocaleString();
  updateAll();
}
|
|
|
|
|
/** Read the position slider, refresh its readout, and redraw. */
function updatePosition() {
  const position = parseInt(positionSlider.value, 10);
  state.position = position;
  positionValue.textContent = position;
  updateAll();
}
|
|
|
|
|
/** Read the relative-offset slider (may be negative), refresh its readout, and redraw. */
function updateRelativePos() {
  const offset = parseInt(relativePosSlider.value, 10);
  state.relativePos = offset;
  relativePosValue.textContent = offset;
  updateAll();
}
|
|
|
|
|
/** Flip the animation on or off and relabel the button to offer the opposite action. */
function toggleAnimation() {
  const nowAnimating = !state.isAnimating;
  state.isAnimating = nowAnimating;
  animateBtn.textContent = nowAnimating ? 'Stop Animation' : 'Animate Rotation';

  if (!nowAnimating) {
    // Drop the frame that animate() last scheduled.
    cancelAnimationFrame(state.animationFrame);
    return;
  }
  animate();
}
|
|
|
|
|
/**
 * Advance the position by one per animation frame, mirror it into the slider
 * and its readout, redraw, and schedule the next frame while animating.
 */
function animate() {
  if (!state.isAnimating) return;

  // Wrap after the slider's own (inclusive) maximum so the animation visits
  // every position the slider offers, including the maximum itself.
  const sliderMax = parseInt(positionSlider.max, 10);
  const maxPosition = Number.isFinite(sliderMax) ? sliderMax : 1024;
  state.position = (state.position + 1) % (maxPosition + 1);
  positionSlider.value = state.position;
  positionValue.textContent = state.position;

  updateAll();
  state.animationFrame = requestAnimationFrame(animate);
}
|
|
|
|
|
/** Redraw every view from the current state, always in the same order. */
function updateAll() {
  const renderers = [
    renderVectorVisualization,
    renderFrequencyChart,
    renderVectorDetails,
    calculateDotProducts
  ];
  for (const render of renderers) {
    render();
  }
}
|
|
|
|
|
|
|
|
/**
 * Return the sample query/key vectors shared by every view.
 *
 * They are created on first use and re-created only when the model dimension
 * changes. All views therefore describe the same vectors, and values stay
 * stable across slider moves and animation frames, so the effect of the
 * position alone is visible. The cache is stored on `state`.
 *
 * @returns {{dimension: number, q: number[], k: number[]}}
 */
function getSampleVectors() {
  const cached = state.sampleVectors;
  if (!cached || cached.dimension !== state.dimension) {
    state.sampleVectors = {
      dimension: state.dimension,
      q: randomVector(state.dimension),
      k: randomVector(state.dimension)
    };
  }
  return state.sampleVectors;
}

/**
 * Build an array of `length` values drawn uniformly from [-1, 1).
 * @param {number} length
 * @returns {number[]}
 */
function randomVector(length) {
  return Array.from({ length }, () => Math.random() * 2 - 1);
}

/** Stroke a 2px line from (fromX, fromY) to the point `to` in `color`. */
function drawRay(ctx, fromX, fromY, to, color) {
  ctx.strokeStyle = color;
  ctx.lineWidth = 2;
  ctx.beginPath();
  ctx.moveTo(fromX, fromY);
  ctx.lineTo(to.x, to.y);
  ctx.stroke();
}

/**
 * Draw dimension pair (0, 1) of the sample query vector before (green) and
 * after (blue) RoPE rotation at the current position. Component 2, when
 * present, only scales both arrows slightly to suggest perspective.
 */
function renderVectorVisualization() {
  const ctx = vectorCtx;
  const width = vectorCanvas.width;
  const height = vectorCanvas.height;
  const centerX = width / 2;
  const centerY = height / 2;
  const scale = Math.min(width, height) * 0.3;

  ctx.clearRect(0, 0, width, height);

  // Horizontal and vertical axes through the centre.
  ctx.strokeStyle = '#e5e7eb';
  ctx.lineWidth = 1;
  ctx.beginPath();
  ctx.moveTo(0, centerY);
  ctx.lineTo(width, centerY);
  ctx.moveTo(centerX, 0);
  ctx.lineTo(centerX, height);
  ctx.stroke();

  const { q } = getSampleVectors();
  // applyRoPE rotates every pair; only pair 0 (angle = position * frequencies[0]) is drawn.
  const rotated = applyRoPE(q, state.position, calculateFrequencies());

  // Rotating pair (0, 1) leaves component 2 untouched, so one perspective factor
  // applies to both arrows. Guarded in case a 2-dimensional model is ever offered.
  const z = q.length > 2 ? q[2] : 0;
  const perspective = 1 + z * 0.2;
  const project = (x, y) => ({
    x: centerX + x * scale * perspective,
    y: centerY - y * scale * perspective
  });

  const origProj = project(q[0], q[1]);
  const rotProj = project(rotated[0], rotated[1]);

  drawRay(ctx, centerX, centerY, origProj, '#10b981');
  drawRay(ctx, centerX, centerY, rotProj, '#3b82f6');

  ctx.fillStyle = '#111827';
  ctx.font = '12px sans-serif';
  ctx.fillText('Original', origProj.x + 5, origProj.y - 5);
  ctx.fillText('Rotated', rotProj.x + 5, rotProj.y + 15);

  ctx.fillStyle = '#6b7280';
  ctx.font = '10px sans-serif';
  ctx.fillText('Dimensions 0,1', 10, 20);
}

/**
 * Push the per-pair frequencies into the spectrum chart.
 *
 * Frequencies depend only on the dimension and base frequency, so position
 * changes (including every animation frame) skip the chart update.
 */
function renderFrequencyChart() {
  const key = `${state.dimension}:${state.baseFreq}`;
  if (state.freqChartKey === key) return;
  state.freqChartKey = key;

  const frequencies = calculateFrequencies();
  // 0-based labels so the x axis matches ωᵢ = 1/θ^(2i/d) and the table's pair index.
  freqChart.data.labels = frequencies.map((_, i) => i);
  freqChart.data.datasets[0].data = frequencies;
  freqChart.update();
}

/**
 * Fill the component table with the first (up to five) dimension pairs of the
 * sample query vector. Each row shows the original values, the values rotated
 * to the current position, the rotation angle and the pair's frequency. A
 * summary row reports how many further pairs are not shown.
 */
function renderVectorDetails() {
  const frequencies = calculateFrequencies();
  const { q } = getSampleVectors();
  const rotated = applyRoPE(q, state.position, frequencies);
  const totalPairs = Math.floor(state.dimension / 2);
  const pairsToShow = Math.min(5, totalPairs);
  const cell = 'px-6 py-4 whitespace-nowrap text-sm font-mono';

  // Build the rows off-DOM, then swap them in with a single append.
  const fragment = document.createDocumentFragment();
  for (let i = 0; i < pairsToShow; i++) {
    const dim1 = 2 * i;
    const dim2 = dim1 + 1;
    const angle = state.position * frequencies[i];

    const row = document.createElement('tr');
    row.className = 'rope-vector';
    // Every interpolated value is a number formatted here, so innerHTML is safe.
    row.innerHTML = `
      <td class="${cell} text-gray-900">${dim1}, ${dim2}</td>
      <td class="${cell} text-gray-500">${q[dim1].toFixed(3)}, ${q[dim2].toFixed(3)}</td>
      <td class="${cell} text-blue-600">${rotated[dim1].toFixed(3)}, ${rotated[dim2].toFixed(3)}</td>
      <td class="${cell} text-gray-900">${angle.toFixed(5)} rad</td>
      <td class="${cell} text-gray-900">${frequencies[i].toExponential(3)}</td>
    `;
    fragment.appendChild(row);
  }

  if (pairsToShow < totalPairs) {
    const row = document.createElement('tr');
    row.innerHTML = `
      <td colspan="5" class="px-6 py-2 text-center text-sm text-gray-500">... ${totalPairs - pairsToShow} more pairs ...</td>
    `;
    fragment.appendChild(row);
  }

  vectorDetails.innerHTML = '';
  vectorDetails.appendChild(fragment);
}

/**
 * Show q·k for the sample vectors and R(q, p)·R(k, p + Δ), where p is the
 * current position and Δ the relative offset, plus their absolute difference.
 *
 * Each pair rotation composes, so R(q, p)·R(k, p + Δ) equals R(q, 0)·R(k, Δ).
 * With fixed sample vectors the rotated value therefore depends only on Δ
 * (up to floating-point rounding) and matches q·k when Δ = 0.
 */
function calculateDotProducts() {
  const { q, k } = getSampleVectors();
  const frequencies = calculateFrequencies();

  const originalDotValue = dotProduct(q, k);
  const rotatedQ = applyRoPE(q, state.position, frequencies);
  // p + Δ may be negative; that simply rotates the key the other way.
  const rotatedK = applyRoPE(k, state.position + state.relativePos, frequencies);
  const rotatedDotValue = dotProduct(rotatedQ, rotatedK);

  originalDot.textContent = originalDotValue.toFixed(4);
  rotatedDot.textContent = rotatedDotValue.toFixed(4);
  dotDifference.textContent = Math.abs(originalDotValue - rotatedDotValue).toFixed(4);
}
|
|
|
|
|
|
|
|
/**
 * Compute the RoPE frequency ωᵢ = 1 / base^(2i / dimension) for each of the
 * floor(dimension / 2) dimension pairs, i = 0, 1, ...
 *
 * Both arguments default to the current `state`, so existing zero-argument
 * calls behave exactly as before; passing them explicitly makes the function
 * usable (and testable) independently of the page state.
 *
 * @param {number} [dimension=state.dimension] model dimension d
 * @param {number} [base=state.baseFreq] base frequency θ
 * @returns {number[]} one frequency per pair, decreasing geometrically with i
 *   (empty when dimension < 2)
 */
function calculateFrequencies(dimension = state.dimension, base = state.baseFreq) {
  const numPairs = Math.floor(dimension / 2);
  return Array.from({ length: numPairs }, (_, i) => 1 / Math.pow(base, (2 * i) / dimension));
}
|
|
|
|
|
/**
 * Apply rotary positional embedding to `vector` at `position`.
 *
 * Pair i (components 2i and 2i+1) is rotated by the angle
 * position * frequencies[i]. A trailing unpaired component (odd length) is
 * copied unchanged. The input array is not modified.
 *
 * @param {number[]} vector values to rotate
 * @param {number} position token position (may be negative)
 * @param {number[]} frequencies one frequency per pair; needs at least
 *   floor(vector.length / 2) entries
 * @returns {number[]} a new array with every pair rotated
 */
function applyRoPE(vector, position, frequencies) {
  const rotated = vector.slice();
  const numPairs = Math.floor(vector.length / 2);

  for (let i = 0; i < numPairs; i++) {
    const angle = position * frequencies[i];
    // Built-in Math is sufficient for scalar trig; evaluate each once per pair.
    const cos = Math.cos(angle);
    const sin = Math.sin(angle);
    const x = vector[2 * i];
    const y = vector[2 * i + 1];
    rotated[2 * i] = x * cos - y * sin;
    rotated[2 * i + 1] = x * sin + y * cos;
  }

  return rotated;
}
|
|
|
|
|
/**
 * Sum a[i] * b[i] over the elements of `a`, accumulated left to right from 0.
 * Returns 0 for an empty `a`.
 */
function dotProduct(a, b) {
  let total = 0;
  a.forEach((value, index) => {
    total += value * b[index];
  });
  return total;
}
|
|
}); |
|
|
</script> |
|
|
</body> |
|
|
</html> |