ucalyptus's picture
Update index.html
38a56bc verified
raw
history blame
18.4 kB
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>FlashInfer Attention State Visualization</title>
<script src="https://cdn.tailwindcss.com"></script>
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap" rel="stylesheet">
<style>
body {
font-family: 'Inter', sans-serif;
background-color: #f3f4f6; /* Tailwind gray-100 */
}
canvas {
background-color: #ffffff; /* White */
border-radius: 0.5rem; /* rounded-lg */
box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1); /* shadow-md */
}
.btn {
padding: 0.5rem 1rem; /* py-2 px-4 */
border-radius: 0.375rem; /* rounded-md */
font-weight: 600; /* font-semibold */
color: white;
background-color: #4f46e5; /* indigo-600 */
transition: background-color 0.2s;
cursor: pointer;
margin: 0 0.5rem; /* mx-2 */
box-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05); /* shadow-sm */
}
.btn:hover {
background-color: #4338ca; /* indigo-700 */
}
.btn:disabled {
background-color: #a5b4fc; /* indigo-300 */
cursor: not-allowed;
}
.info-text {
color: #4b5563; /* gray-600 */
font-size: 0.875rem; /* text-sm */
margin-top: 0.5rem; /* mt-2 */
}
.state-box {
border: 2px solid;
border-radius: 0.375rem; /* rounded-md */
padding: 5px;
text-align: center;
font-size: 0.8rem;
background-color: rgba(255, 255, 255, 0.8);
}
.s-value {
font-weight: bold;
color: #1d4ed8; /* blue-700 */
}
.v-value {
display: inline-block;
width: 15px;
height: 15px;
border-radius: 50%;
margin-left: 5px;
vertical-align: middle;
}
</style>
</head>
<body class="flex flex-col items-center justify-center min-h-screen p-4">
<h1 class="text-2xl font-semibold text-gray-800 mb-4">FlashInfer: Attention States & Recursive Merge</h1>
<p class="text-center text-gray-600 mb-6 max-w-2xl">
This visualization demonstrates how FlashInfer computes attention by:
<br>1. Calculating partial "Attention States" (s, v) for subsets of Key-Value pairs.
<br>2. Recursively merging these states using the ⊕ operator to get the final result.
</p>
<canvas id="attentionCanvas" width="800" height="400"></canvas>
<p id="statusText" class="info-text h-6"></p> <div class="mt-6 flex justify-center">
<button id="resetBtn" class="btn">Reset & Initialize</button>
<button id="computeStatesBtn" class="btn" disabled>Compute Partial States</button>
<button id="mergeStatesBtn" class="btn" disabled>Merge States</button>
</div>
<script>
// DOM handles used throughout the script, looked up once by element id.
const byId = (id) => document.getElementById(id);
const canvas = byId('attentionCanvas');
const ctx = canvas.getContext('2d');
const statusText = byId('statusText');
const resetBtn = byId('resetBtn');
const computeStatesBtn = byId('computeStatesBtn');
const mergeStatesBtn = byId('mergeStatesBtn');
// Id of the most recently requested animation frame; starts out null.
let animationFrameId = null;
// --- Configuration ---
// Layout and colour settings for the canvas scene. The script only reads
// these values, so the object (and its nested objects) is frozen to catch
// accidental writes.
const config = Object.freeze({
  queryPos: Object.freeze({ x: 50, y: 200 }), // Centre of the query square
  kvStartY: 50, // y of the first KV pair
  kvSpacingY: 60, // Vertical gap between consecutive KV pairs
  // KV pairs [0, kvPartitionSize) form partition 1 and the remaining pairs form
  // partition 2; the state-box layout assumes kvCount === 2 * kvPartitionSize.
  kvCount: 6,
  kvPartitionSize: 3, // Size of each partition
  stateBoxWidth: 100,
  stateBoxHeight: 50,
  mergePointX: 600, // x of the ⊕ merge symbol
  finalStateX: 700, // Left edge of the final state box
  colors: Object.freeze({
    query: '#ef4444', // red-500
    kv: '#3b82f6', // blue-500
    partition1: '#10b981', // emerald-500
    partition2: '#f97316', // orange-500
    merged: '#8b5cf6', // violet-500
    arrow: '#6b7280', // gray-500
    text: '#1f2937', // gray-800
  }),
  arrowHeadSize: 8, // Length of each arrowhead side
});
// --- State Variables ---
// Mutable scene state; initialize() rebuilds all of it on every reset.
let query = {}; // Query position, copied from config.queryPos
let kvPairs = []; // One { x, y, id } entry per KV pair
// Attention states; each stays null until it has been computed.
let states = { partition1: null, partition2: null, final: null };
// Demo progress: 0 initial, 1 KVs shown, 2 states computed, 3 merged.
let currentStep = 0;
// --- Drawing Functions ---
/**
 * Draws a straight arrow from (startX, startY) towards (endX, endY).
 *
 * Only the first `progress` fraction of the shaft is drawn, so animations can
 * grow the arrow; the arrowhead is added once progress reaches 1. The canvas
 * drawing state is saved and restored around the drawing, so the stroke/fill
 * colour and line width set here do not leak into later drawing calls.
 *
 * @param {number} startX - Tail x coordinate.
 * @param {number} startY - Tail y coordinate.
 * @param {number} endX - Tip x coordinate when fully drawn.
 * @param {number} endY - Tip y coordinate when fully drawn.
 * @param {string} [color=config.colors.arrow] - Stroke and fill colour.
 * @param {number} [progress=1] - Fraction of the arrow to draw, clamped to [0, 1].
 */
function drawArrow(startX, startY, endX, endY, color = config.colors.arrow, progress = 1) {
  const t = Math.min(1, Math.max(0, progress));
  const dx = endX - startX;
  const dy = endY - startY;
  const angle = Math.atan2(dy, dx);
  const tipX = startX + dx * t;
  const tipY = startY + dy * t;

  ctx.save();
  ctx.strokeStyle = color;
  ctx.fillStyle = color;
  ctx.lineWidth = 1.5;

  // Shaft (partial while animating).
  ctx.beginPath();
  ctx.moveTo(startX, startY);
  ctx.lineTo(tipX, tipY);
  ctx.stroke();

  if (t >= 1) {
    // Filled triangular head, its two sides 30° either side of the shaft.
    const size = config.arrowHeadSize;
    const spread = Math.PI / 6;
    ctx.beginPath();
    ctx.moveTo(tipX, tipY);
    ctx.lineTo(tipX - size * Math.cos(angle - spread), tipY - size * Math.sin(angle - spread));
    ctx.lineTo(tipX - size * Math.cos(angle + spread), tipY - size * Math.sin(angle + spread));
    ctx.closePath();
    ctx.fill();
  }
  ctx.restore();
}
/**
 * Draws the query as a filled 30x30 square centred on (q.x, q.y), labelled "Q".
 * @param {{x: number, y: number}} q - Query position.
 */
function drawQuery(q) {
  const { x, y } = q;
  const half = 15;
  ctx.fillStyle = config.colors.query;
  ctx.fillRect(x - half, y - half, half * 2, half * 2);
  // Centred label; the +5 drops the baseline so the glyph sits mid-square.
  Object.assign(ctx, { fillStyle: config.colors.text, font: 'bold 14px Inter', textAlign: 'center' });
  ctx.fillText('Q', x, y + 5);
}
/**
 * Draws one key/value pair: a triangle (K) whose tip is at (kv.x, kv.y), a
 * circle (V) just to its right, and a 1-based "KV n" caption.
 * @param {{x: number, y: number}} kv - Position of the pair.
 * @param {number} index - Zero-based index of the pair.
 */
function drawKVPair(kv, index) {
  const { x, y } = kv;
  ctx.fillStyle = config.colors.kv;

  // K: right-pointing triangle.
  ctx.beginPath();
  ctx.moveTo(x - 10, y - 10);
  ctx.lineTo(x, y);
  ctx.lineTo(x - 10, y + 10);
  ctx.closePath();
  ctx.fill();

  // V: circle of radius 10, same colour.
  ctx.beginPath();
  ctx.arc(x + 10, y, 10, 0, 2 * Math.PI);
  ctx.fill();

  // Caption.
  ctx.fillStyle = config.colors.text;
  ctx.font = '12px Inter';
  ctx.textAlign = 'left';
  ctx.fillText(`KV ${index + 1}`, x + 25, y + 4);
}
/**
 * Draws an attention-state box: an outlined translucent rectangle with the
 * state's label on top and a row showing "s: <value>" and "v: <swatch>".
 * Does nothing when `state` is falsy (not yet computed).
 *
 * @param {?{x: number, y: number, s: number, vColor: string, label: string}} state
 * @param {string} color - Outline colour.
 */
function drawAttentionState(state, color) {
  if (!state) return;
  const { x, y, label, s, vColor } = state;
  const boxW = config.stateBoxWidth;
  const boxH = config.stateBoxHeight;
  const rowY = y + 35;

  // Frame.
  ctx.strokeStyle = color;
  ctx.lineWidth = 2;
  ctx.fillStyle = 'rgba(255, 255, 255, 0.9)';
  ctx.fillRect(x, y, boxW, boxH);
  ctx.strokeRect(x, y, boxW, boxH);

  // Title, centred along the top.
  ctx.fillStyle = config.colors.text;
  ctx.font = 'bold 12px Inter';
  ctx.textAlign = 'center';
  ctx.fillText(label, x + boxW / 2, y + 15);

  // Value row: s (the simulated LSE) as a number, then v as a colour swatch.
  ctx.font = '11px Inter';
  ctx.fillText('s:', x + 25, rowY);
  ctx.textAlign = 'left';
  ctx.fillText(s.toFixed(2), x + 35, rowY);
  ctx.textAlign = 'center';
  ctx.fillText('v:', x + 70, rowY);
  ctx.fillStyle = vColor;
  ctx.beginPath();
  ctx.arc(x + 85, y + 30, 7, 0, Math.PI * 2);
  ctx.fill();
}
// --- Animation Loop ---
// Progress of the running animation, in [0, 1].
let progress = 0;
// Wall-clock length of one step animation. Progress is derived from elapsed
// time rather than a fixed per-frame increment, so the animation runs at the
// same speed on 60 Hz and high-refresh displays (about the pace of the former
// 0.02-per-frame increment at 60 fps).
const animationDurationMs = 800;
/**
 * Animates one visualization step from progress 0 to 1.
 *
 * Any previously scheduled frame is cancelled first. The step buttons are
 * disabled while the animation runs so it cannot be re-triggered mid-flight.
 * When progress reaches 1, `currentStep` becomes `nextStep`, the buttons are
 * refreshed and the status line is cleared.
 *
 * @param {(p: number) => void} stepFunction - Draws a complete frame for progress p.
 * @param {number} nextStep - Value assigned to currentStep when the animation ends.
 */
function animateStep(stepFunction, nextStep) {
  cancelAnimationFrame(animationFrameId);
  computeStatesBtn.disabled = true;
  mergeStatesBtn.disabled = true;
  progress = 0;
  let startTime = null;

  function loop(timestamp) {
    if (startTime === null) startTime = timestamp;
    progress = Math.min(1, (timestamp - startTime) / animationDurationMs);
    stepFunction(progress);
    if (progress < 1) {
      animationFrameId = requestAnimationFrame(loop);
      return;
    }
    // Finished: no frame is pending any more.
    animationFrameId = null;
    currentStep = nextStep;
    updateButtons();
    setStatus('');
  }

  animationFrameId = requestAnimationFrame(loop);
}
// --- Visualization Steps ---
/**
 * Resets the demo. It stops any scheduled animation frame, clears the canvas
 * and the computed states, and rebuilds the query and KV positions from
 * `config`. It then draws them, moves to step 1 and refreshes the buttons and
 * status line.
 */
function initialize() {
  cancelAnimationFrame(animationFrameId);
  ctx.clearRect(0, 0, canvas.width, canvas.height);
  currentStep = 0;
  states = { partition1: null, partition2: null, final: null };

  // Copy so later changes to `query` never touch config.queryPos.
  query = { ...config.queryPos };

  // One column of KV pairs, 100px right of the query, kvSpacingY apart.
  const columnX = config.queryPos.x + 100;
  kvPairs = [];
  for (let id = 0; id < config.kvCount; id += 1) {
    kvPairs.push({ x: columnX, y: config.kvStartY + id * config.kvSpacingY, id });
  }

  drawQuery(query);
  kvPairs.forEach((pair, position) => drawKVPair(pair, position));

  currentStep = 1;
  updateButtons();
  setStatus('Initialized Query and KV Pairs.');
}
/**
 * Draws one frame of the "compute partial states" animation.
 *
 * KV pairs [0, kvPartitionSize) form partition 1 and the remaining pairs
 * [kvPartitionSize, kvCount) form partition 2. For each partition, arrows grow
 * from the query to its KV pairs. From p = 0.5 onward, arrows also grow from
 * those KV pairs to the partition's state box, which fades in.
 *
 * The partition states are created on the first call (with a simulated,
 * random LSE value `s`) and reused afterwards, so redrawing with p = 1 leaves
 * them unchanged.
 *
 * @param {number} p - Animation progress in [0, 1].
 */
function computePartialStatesStep(p) {
  ctx.clearRect(0, 0, canvas.width, canvas.height);
  drawQuery(query);
  kvPairs.forEach((kv, i) => drawKVPair(kv, i));

  const stateX = config.queryPos.x + 250; // Both state boxes share one column
  const split = config.kvPartitionSize;
  drawPartitionStep('partition1', 0, split, stateX, config.colors.partition1, p);
  drawPartitionStep('partition2', split, config.kvCount, stateX, config.colors.partition2, p);
}

// Creates (once) and draws the attention state for KV pairs [start, end) at progress p.
function drawPartitionStep(key, start, end, stateX, color, p) {
  if (!states[key]) {
    // Vertically centre the box on the middle KV pair of the range.
    const centerIndex = (start + end - 1) / 2;
    states[key] = {
      x: stateX,
      y: config.kvStartY + centerIndex * config.kvSpacingY - config.stateBoxHeight / 2,
      s: Math.random() * 5 + 5, // Simulated LSE
      vColor: color,
      label: `State ${start + 1}..${end}`,
    };
  }
  const state = states[key];
  const members = kvPairs.slice(start, end);

  // Query -> KV arrows.
  members.forEach((kv) => drawArrow(query.x + 15, query.y, kv.x - 15, kv.y, color, p));
  if (p < 0.5) return;

  // Second half of the animation: KV -> state arrows plus a fading-in state box.
  const stateProgress = (p - 0.5) * 2;
  const targetY = state.y + config.stateBoxHeight / 2;
  members.forEach((kv) => drawArrow(kv.x + 15, kv.y, state.x, targetY, color, stateProgress));
  ctx.globalAlpha = stateProgress;
  drawAttentionState(state, color);
  ctx.globalAlpha = 1.0;
}
/**
 * Draws one frame of the merge animation.
 *
 * The finished partial-states frame is redrawn first. Arrows then grow from
 * both partial states to the ⊕ merge point. From p = 0.5 onward, the ⊕ symbol,
 * the arrow to the final state and the final state box fade in.
 *
 * The final state is created on the first call. Its `s` is the log-sum-exp
 * of the partial `s` values, evaluated as max + log1p(exp(-|s1 - s2|)) so
 * large values cannot overflow `Math.exp` to Infinity. The true merged v
 * would be the average of v1 and v2 weighted by exp(s_i - s). The colour
 * shown here is only an unweighted visual blend of the two partition colours.
 *
 * @param {number} p - Animation progress in [0, 1].
 */
function mergeStatesStep(p) {
  computePartialStatesStep(1);

  const mergePointY = canvas.height / 2;
  if (!states.final) {
    const { partition1: first, partition2: second } = states;
    const sMax = Math.max(first.s, second.s);
    // Non-finite maxima (e.g. both -Infinity) would make the difference NaN.
    const sFinal = Number.isFinite(sMax)
      ? sMax + Math.log1p(Math.exp(-Math.abs(first.s - second.s)))
      : sMax;
    states.final = {
      x: config.finalStateX,
      y: mergePointY - config.stateBoxHeight / 2,
      s: sFinal,
      vColor: averageHexColors(first.vColor, second.vColor),
      label: `Final State (1..${config.kvCount})`,
    };
  }

  // Arrows from each partial state's right edge to just before the ⊕ symbol.
  const mergeArrowEndX = config.mergePointX - 10;
  const sources = [
    [states.partition1, config.colors.partition1],
    [states.partition2, config.colors.partition2],
  ];
  for (const [state, color] of sources) {
    drawArrow(state.x + config.stateBoxWidth, state.y + config.stateBoxHeight / 2, mergeArrowEndX, mergePointY, color, p);
  }
  if (p < 0.5) return;

  // Second half: ⊕ symbol, arrow to the final state, and the final box, all fading in.
  const symbolProgress = (p - 0.5) * 2;
  ctx.globalAlpha = symbolProgress;
  ctx.font = 'bold 30px Inter';
  ctx.fillStyle = config.colors.merged;
  ctx.textAlign = 'center';
  ctx.fillText('⊕', config.mergePointX, mergePointY + 10); // +10 roughly centres the glyph vertically
  const finalArrowStartX = config.mergePointX + 15; // Start just right of the symbol
  drawArrow(finalArrowStartX, mergePointY, states.final.x, states.final.y + config.stateBoxHeight / 2, config.colors.merged, symbolProgress);
  drawAttentionState(states.final, config.colors.merged);
  ctx.globalAlpha = 1.0;
}
// --- Helper Functions ---
/**
 * Replaces the status line under the canvas with `text`. The text is written
 * via textContent, so it is shown as plain text rather than parsed as HTML.
 * @param {string} text - Message to show; '' clears the line.
 */
function setStatus(text) {
  const line = statusText;
  line.textContent = text;
}
/**
 * Syncs the button enabled/disabled states with `currentStep`. Reset is always
 * available. "Compute" is enabled only for steps in [1, 2) and "Merge" only for
 * steps in [2, 3).
 */
function updateButtons() {
  const outsideRange = (lo, hi) => currentStep < lo || currentStep >= hi;
  resetBtn.disabled = false;
  computeStatesBtn.disabled = outsideRange(1, 2);
  mergeStatesBtn.disabled = outsideRange(2, 3);
}
/**
 * Blends two `#rrggbb` colours channel by channel (used for the merged v swatch).
 *
 * @param {string} color1 - First colour, `#rrggbb`.
 * @param {string} color2 - Second colour, `#rrggbb`.
 * @param {number} [weight=0.5] - Share of color2 in the result, clamped to
 *   [0, 1]; non-finite values fall back to 0.5. The default gives the plain
 *   per-channel average, rounded to the nearest integer.
 * @returns {string} The blended colour as lowercase `#rrggbb`.
 */
function averageHexColors(color1, color2, weight = 0.5) {
  const w = Number.isFinite(weight) ? Math.min(1, Math.max(0, weight)) : 0.5;
  const toChannels = (hex) => {
    const value = Number.parseInt(hex.slice(1), 16);
    return [(value >> 16) & 255, (value >> 8) & 255, value & 255];
  };
  const from = toChannels(color1);
  const to = toChannels(color2);
  const hex = from
    .map((channel, i) => Math.round(channel * (1 - w) + to[i] * w).toString(16).padStart(2, '0'))
    .join('');
  return `#${hex}`;
}
// --- Event Listeners ---
// Reset always reinitializes. Each step button acts only when the demo is at
// the step that precedes it, then animates forward to the next step.
resetBtn.addEventListener('click', () => initialize());

computeStatesBtn.addEventListener('click', () => {
  if (currentStep !== 1) return;
  setStatus('Computing partial attention states...');
  animateStep(computePartialStatesStep, 2);
});

mergeStatesBtn.addEventListener('click', () => {
  if (currentStep !== 2) return;
  setStatus('Merging attention states...');
  animateStep(mergeStatesStep, 3);
});
// --- Initial Setup ---
// Draw the initial scene once the page has loaded. The canvas keeps the fixed
// 800x400 backing size from its HTML attributes; no devicePixelRatio scaling
// is applied.
window.onload = () => {
  initialize();
};
// Optional: Redraw on resize
// Repaint the current scene without resetting it. Calling initialize() here
// would be wrong: currentStep stays 1 while the "compute" animation runs, so
// a resize mid-animation would cancel that animation and discard the states it
// had created. Any in-flight animation simply overwrites this repaint on its
// next frame.
window.addEventListener('resize', () => {
  if (currentStep === 1) {
    ctx.clearRect(0, 0, canvas.width, canvas.height);
    drawQuery(query);
    kvPairs.forEach((kv, i) => drawKVPair(kv, i));
  } else if (currentStep === 2) {
    computePartialStatesStep(1); // Draw completed step
  } else if (currentStep === 3) {
    mergeStatesStep(1); // Draw completed step
  }
});
</script>
</body>
</html>