|
<!DOCTYPE html> |
|
<html lang="en"> |
|
<head> |
|
<meta charset="UTF-8"> |
|
<meta name="viewport" content="width=device-width, initial-scale=1.0"> |
|
<title>FlashInfer Attention State Visualization</title> |
|
<script src="https://cdn.tailwindcss.com"></script> |
|
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap" rel="stylesheet"> |
|
<style>
  /* Page typography and background. */
  body {
    font-family: 'Inter', sans-serif;
    background-color: #f3f4f6;
  }

  /* Drawing surface. The canvas keeps its width/height attribute coordinates for
     drawing; max-width + height:auto only scale it down visually (aspect ratio
     preserved) so it never overflows narrow viewports. */
  canvas {
    max-width: 100%;
    height: auto;
    background-color: #ffffff;
    border-radius: 0.5rem;
    box-shadow: 0 4px 6px -1px rgb(0 0 0 / 0.1), 0 2px 4px -2px rgb(0 0 0 / 0.1);
  }

  /* Primary action button. */
  .btn {
    padding: 0.5rem 1rem;
    border-radius: 0.375rem;
    font-weight: 600;
    color: white;
    background-color: #4f46e5;
    transition: background-color 0.2s;
    cursor: pointer;
    margin: 0 0.5rem;
    box-shadow: 0 1px 2px 0 rgb(0 0 0 / 0.05);
  }

  .btn:hover {
    background-color: #4338ca;
  }

  /* Clearly visible keyboard focus indicator. */
  .btn:focus-visible {
    outline: 2px solid #4338ca;
    outline-offset: 2px;
  }

  /* Declared after .btn:hover (same specificity) so a disabled button never
     shows the hover colour. */
  .btn:disabled {
    background-color: #a5b4fc;
    cursor: not-allowed;
  }

  /* Status line under the canvas. */
  .info-text {
    color: #4b5563;
    font-size: 0.875rem;
    margin-top: 0.5rem;
  }

  /* NOTE(review): .state-box, .s-value and .v-value are not referenced by the
     markup or script of this page (states are drawn on the canvas); kept in case
     they are wanted for an HTML rendering of the states. */
  .state-box {
    border: 2px solid;
    border-radius: 0.375rem;
    padding: 5px;
    text-align: center;
    font-size: 0.8rem;
    background-color: rgba(255, 255, 255, 0.8);
  }

  .s-value {
    font-weight: bold;
    color: #1d4ed8;
  }

  .v-value {
    display: inline-block;
    width: 15px;
    height: 15px;
    border-radius: 50%;
    margin-left: 5px;
    vertical-align: middle;
  }
</style>
|
</head> |
|
<body class="flex flex-col items-center justify-center min-h-screen p-4"> |
|
|
|
<h1 class="text-2xl font-semibold text-gray-800 mb-4">FlashInfer: Attention States & Recursive Merge</h1> |
|
<p class="text-center text-gray-600 mb-6 max-w-2xl"> |
|
This visualization demonstrates how FlashInfer computes attention by:
<br>1. Calculating partial "attention states" (s, v) for subsets of the key-value pairs, where s is the log-sum-exp of the attention scores and v is the softmax-weighted sum of the values.
<br>2. Recursively merging these states with the ⊕ operator to obtain the final result.
|
</p> |
|
|
|
<canvas class="" id="attentionCanvas" width="800" height="400" role="img" aria-label="Animated diagram: a query attends to key-value pairs split into two partitions, each partition yields a partial attention state (s, v), and the two states are merged with the ⊕ operator into the final state."></canvas>

<!-- Status line written by the script; role="status" makes screen readers announce updates politely. -->
<p class="info-text h-6" id="statusText" role="status"></p>

<div class="mt-6 flex justify-center">
  <!-- type="button": these buttons only trigger script actions, never form submission. -->
  <button class="btn" id="resetBtn" type="button">Reset & Initialize</button>
  <button class="btn" id="computeStatesBtn" type="button" disabled>Compute Partial States</button>
  <button class="btn" id="mergeStatesBtn" type="button" disabled>Merge States</button>
</div>
|
|
|
<script> |
|
// DOM handles used throughout the script. They are looked up once; the script
// element comes after the canvas, status line and buttons in the document.
const canvas = document.getElementById('attentionCanvas');
const ctx = canvas.getContext('2d');
const [statusText, resetBtn, computeStatesBtn, mergeStatesBtn] = [
    'statusText',
    'resetBtn',
    'computeStatesBtn',
    'mergeStatesBtn',
].map((id) => document.getElementById(id));

// Handle of the most recently requested animation frame (null before any animation),
// kept so a running animation can be cancelled.
let animationFrameId = null;
|
|
|
|
|
/**
 * Layout, sizing and colour settings for the visualization (canvas pixels).
 */
const config = {
    // Centre of the query marker.
    queryPos: { x: 50, y: 200 },
    // y of the first KV pair and vertical distance between consecutive pairs.
    kvStartY: 50,
    kvSpacingY: 60,
    // Total number of KV pairs.
    kvCount: 6,
    // Number of KV pairs in the first partition; the remaining
    // kvCount - kvPartitionSize pairs form the second partition.
    kvPartitionSize: 3,
    // Size of an attention-state card.
    stateBoxWidth: 100,
    stateBoxHeight: 50,
    // x of the ⊕ merge symbol.
    mergePointX: 600,
    // Left edge of the final-state card. 690 + stateBoxWidth keeps the card and its
    // 2px border fully inside the 800px-wide canvas (700 put the border on the edge).
    finalStateX: 690,
    colors: {
        query: '#ef4444',      // query marker
        kv: '#3b82f6',         // KV pair glyphs
        partition1: '#10b981', // arrows/card of the first partition
        partition2: '#f97316', // arrows/card of the second partition
        merged: '#8b5cf6',     // ⊕ symbol, output arrow and final card
        arrow: '#6b7280',      // default arrow colour
        text: '#1f2937',       // labels
    },
    // Length of the two sides of an arrowhead.
    arrowHeadSize: 8,
};
|
|
|
|
|
// Mutable scene state, rebuilt by initialize() and filled in by the step renderers:
//   query       - position {x, y} of the query marker
//   kvPairs     - [{x, y, id}] for each KV pair, top to bottom
//   states      - attention-state cards, created lazily by the step renderers
//   currentStep - 0 before initialization, 1 after it, 2 after the partial
//                 states are computed, 3 after the merge
let query = {},
    kvPairs = [],
    states = { partition1: null, partition2: null, final: null },
    currentStep = 0;
|
|
|
|
|
/**
 * Draws a straight arrow on the shared 2D context from (startX, startY) towards
 * (endX, endY).
 *
 * @param {number} startX - x of the tail.
 * @param {number} startY - y of the tail.
 * @param {number} endX - x of the head.
 * @param {number} endY - y of the head.
 * @param {string} [color=config.colors.arrow] - Colour of the shaft and the head.
 * @param {number} [progress=1] - Fraction of the shaft to draw, clamped to [0, 1].
 *        The arrowhead is drawn only once the shaft is complete, which lets the
 *        animation "grow" arrows frame by frame.
 */
function drawArrow(startX, startY, endX, endY, color = config.colors.arrow, progress = 1) {
    const t = Math.min(Math.max(progress, 0), 1);
    const dx = endX - startX;
    const dy = endY - startY;
    // Direction of the full arrow; the head is oriented along it.
    const angle = Math.atan2(dy, dx);
    const tipX = startX + dx * t;
    const tipY = startY + dy * t;

    ctx.beginPath();
    ctx.moveTo(startX, startY);
    ctx.lineTo(tipX, tipY);
    ctx.strokeStyle = color;
    ctx.lineWidth = 1.5;
    ctx.stroke();

    if (t < 1) return;

    // Arrowhead: filled triangle whose back corners sit arrowHeadSize away from the
    // tip, at ±30° (π/6) either side of the shaft.
    const size = config.arrowHeadSize;
    ctx.beginPath();
    ctx.moveTo(tipX, tipY);
    ctx.lineTo(tipX - size * Math.cos(angle - Math.PI / 6), tipY - size * Math.sin(angle - Math.PI / 6));
    ctx.lineTo(tipX - size * Math.cos(angle + Math.PI / 6), tipY - size * Math.sin(angle + Math.PI / 6));
    ctx.closePath();
    ctx.fillStyle = color;
    ctx.fill();
}
|
|
|
/**
 * Draws the query marker: a 30x30 square in the query colour centred on
 * (q.x, q.y), with a bold "Q" on top of it.
 *
 * @param {{x: number, y: number}} q - Centre of the marker, in canvas pixels.
 */
function drawQuery(q) {
    const { x, y } = q;
    const size = 30;

    ctx.fillStyle = config.colors.query;
    ctx.fillRect(x - size / 2, y - size / 2, size, size);

    // The baseline sits 5px below the centre so the glyph looks vertically centred.
    ctx.fillStyle = config.colors.text;
    ctx.font = 'bold 14px Inter';
    ctx.textAlign = 'center';
    ctx.fillText('Q', x, y + 5);
}
|
|
|
/**
 * Draws one key/value pair: a right-pointing triangle whose tip touches
 * (kv.x, kv.y), a radius-10 circle just to its right, and a "KV n" caption.
 *
 * @param {{x: number, y: number}} kv - Anchor point (the triangle's tip).
 * @param {number} index - Zero-based position; shown one-based in the caption.
 */
function drawKVPair(kv, index) {
    const { x, y } = kv;
    ctx.fillStyle = config.colors.kv;

    const corners = [[x - 10, y - 10], [x, y], [x - 10, y + 10]];
    ctx.beginPath();
    corners.forEach(([px, py], i) => {
        if (i === 0) {
            ctx.moveTo(px, py);
        } else {
            ctx.lineTo(px, py);
        }
    });
    ctx.closePath();
    ctx.fill();

    ctx.beginPath();
    ctx.arc(x + 10, y, 10, 0, 2 * Math.PI);
    ctx.fill();

    ctx.fillStyle = config.colors.text;
    ctx.font = '12px Inter';
    ctx.textAlign = 'left';
    ctx.fillText('KV ' + (index + 1), x + 25, y + 4);
}
|
|
|
/**
 * Draws an attention-state card: an outlined box with a title, the numeric `s`
 * value (2 decimals) and a colour swatch standing for the `v` value.
 * Does nothing when `state` is null/undefined (state not created yet).
 *
 * @param {{x: number, y: number, s: number, vColor: string, label: string}|null} state
 *        Top-left corner of the card, score, swatch colour and title.
 * @param {string} color - Border colour of the card.
 */
function drawAttentionState(state, color) {
    if (!state) return;
    const w = config.stateBoxWidth;
    const h = config.stateBoxHeight;
    const padding = 4;

    ctx.strokeStyle = color;
    ctx.lineWidth = 2;
    ctx.fillStyle = 'rgba(255, 255, 255, 0.9)';
    ctx.fillRect(state.x, state.y, w, h);
    ctx.strokeRect(state.x, state.y, w, h);

    // Title. maxWidth makes long labels such as "Final State (1..6)" condense to
    // fit inside the card instead of spilling over its border.
    ctx.fillStyle = config.colors.text;
    ctx.font = 'bold 12px Inter';
    ctx.textAlign = 'center';
    ctx.fillText(state.label, state.x + w / 2, state.y + 15, w - 2 * padding);

    // "s:" caption (centred) followed by the left-aligned numeric value.
    ctx.font = '11px Inter';
    ctx.fillText('s:', state.x + 25, state.y + 35);
    ctx.textAlign = 'left';
    ctx.fillText(state.s.toFixed(2), state.x + 35, state.y + 35);

    // "v:" caption followed by the colour swatch.
    ctx.textAlign = 'center';
    ctx.fillText('v:', state.x + 70, state.y + 35);
    ctx.fillStyle = state.vColor;
    ctx.beginPath();
    ctx.arc(state.x + 85, state.y + 30, 7, 0, Math.PI * 2);
    ctx.fill();
}
|
|
|
|
|
// Progress of the step currently being animated, in [0, 1].
let progress = 0;
// Progress per frame at a 60 Hz refresh rate (0.02 -> 50 frames, about 0.83 s).
const animationSpeed = 0.02;

/**
 * Plays one step of the walkthrough. It repeatedly calls `stepFunction(progress)`
 * with progress rising from 0 to 1. Afterwards it sets `currentStep` to `nextStep`,
 * refreshes the buttons and clears the status line.
 *
 * Progress comes from elapsed wall-clock time, not from counted frames. The
 * animation therefore lasts the same on 60 Hz and high-refresh displays. The step
 * buttons are disabled while it runs, so a second click cannot restart it part-way.
 * Reset stays enabled, and initialize() cancels the animation.
 *
 * @param {(p: number) => void} stepFunction - Renders one frame for progress p.
 * @param {number} nextStep - Value assigned to `currentStep` when the animation ends.
 */
function animateStep(stepFunction, nextStep) {
    cancelAnimationFrame(animationFrameId);
    progress = 0;

    computeStatesBtn.disabled = true;
    mergeStatesBtn.disabled = true;

    const durationMs = 1000 / (animationSpeed * 60);
    let startTime = null;

    function loop(timestamp) {
        if (startTime === null) startTime = timestamp;
        progress = Math.min((timestamp - startTime) / durationMs, 1);
        stepFunction(progress);

        if (progress >= 1) {
            currentStep = nextStep;
            updateButtons();
            setStatus('');
            return;
        }
        animationFrameId = requestAnimationFrame(loop);
    }

    animationFrameId = requestAnimationFrame(loop);
}
|
|
|
|
|
|
|
/**
 * Resets the walkthrough to its first step. It:
 * - stops any running animation, clears the canvas and discards all computed states;
 * - lays out the query and the column of KV pairs and draws them;
 * - leaves Reset and "Compute Partial States" as the enabled actions.
 */
function initialize() {
    cancelAnimationFrame(animationFrameId);
    ctx.clearRect(0, 0, canvas.width, canvas.height);
    currentStep = 0;
    states = { partition1: null, partition2: null, final: null };

    // Copy so later changes to `query` never touch config.queryPos.
    query = { ...config.queryPos };

    // KV pairs form a vertical column 100px to the right of the query.
    const columnX = config.queryPos.x + 100;
    kvPairs = Array.from({ length: config.kvCount }, (_, id) => ({
        x: columnX,
        y: config.kvStartY + id * config.kvSpacingY,
        id
    }));

    drawQuery(query);
    kvPairs.forEach((kv, i) => drawKVPair(kv, i));

    currentStep = 1;
    updateButtons();
    setStatus('Initialized Query and KV Pairs.');
}
|
|
|
/**
 * Renders one frame of the "compute partial states" animation.
 *
 * KV pairs [0, kvPartitionSize) form partition 1 and [kvPartitionSize, kvCount)
 * form partition 2. The first half of the animation (p in [0, 0.5]) grows arrows
 * from the query to each partition's KV pairs. The second half grows arrows from
 * those KV pairs into the partition's state card and fades the card in.
 *
 * State objects are created on the first call and reused afterwards, so their
 * random `s` values stay stable across frames.
 *
 * @param {number} p - Animation progress in [0, 1]; 1 renders the finished frame.
 */
function computePartialStatesStep(p) {
    ctx.clearRect(0, 0, canvas.width, canvas.height);
    drawQuery(query);
    kvPairs.forEach((kv, i) => drawKVPair(kv, i));

    const stateX = config.queryPos.x + 250;
    const split = config.kvPartitionSize;

    if (!states.partition1) {
        states.partition1 = {
            x: stateX,
            y: partitionStateY(0, split),
            s: Math.random() * 5 + 5,
            vColor: config.colors.partition1,
            label: `State 1..${split}`
        };
    }
    drawPartitionPhase(0, split, states.partition1, config.colors.partition1, p);

    if (!states.partition2) {
        states.partition2 = {
            x: stateX,
            y: partitionStateY(split, config.kvCount),
            s: Math.random() * 5 + 5,
            vColor: config.colors.partition2,
            label: `State ${split + 1}..${config.kvCount}`
        };
    }
    drawPartitionPhase(split, config.kvCount, states.partition2, config.colors.partition2, p);
}

/**
 * Top y of a state card vertically centred on the KV pairs with indices
 * [first, end). This works for partitions of any size.
 */
function partitionStateY(first, end) {
    const centreIndex = (first + end - 1) / 2;
    return config.kvStartY + centreIndex * config.kvSpacingY - config.stateBoxHeight / 2;
}

/**
 * Draws the query->KV arrows for KV pairs [first, end). In the second half of the
 * animation it also draws the KV->state arrows and fades in the state card.
 */
function drawPartitionPhase(first, end, state, color, p) {
    // Phase 1: query -> KV arrows are complete by p = 0.5, before anything flows
    // out of the KV pairs.
    const queryProgress = Math.min(p * 2, 1);
    for (let i = first; i < end; i++) {
        drawArrow(query.x + 15, query.y, kvPairs[i].x - 15, kvPairs[i].y, color, queryProgress);
    }
    if (p < 0.5) return;

    // Phase 2: KV -> state arrows and card fade-in.
    const stateProgress = (p - 0.5) * 2;
    for (let i = first; i < end; i++) {
        drawArrow(kvPairs[i].x + 15, kvPairs[i].y, state.x, state.y + config.stateBoxHeight / 2, color, stateProgress);
    }
    ctx.globalAlpha = stateProgress;
    drawAttentionState(state, color);
    ctx.globalAlpha = 1.0;
}
|
|
|
/**
 * Renders one frame of the "merge states" animation. It combines the two partial
 * attention states into the final state with the ⊕ operator.
 *
 * For partial states (s1, v1) and (s2, v2):
 *   s = log(exp(s1) + exp(s2))
 *   v = exp(s1 - s) * v1 + exp(s2 - s) * v2
 * Each partial v is weighted by its share of the total exp(s). Here v is shown as a
 * colour swatch, so the merged colour is the same weighted blend of the partition
 * colours, and the dominant partition visibly dominates the result.
 *
 * Timeline: in the first half (p in [0, 0.5]) arrows from both partial states grow
 * towards the merge point. In the second half the ⊕ symbol, the output arrow and
 * the final state fade in.
 *
 * @param {number} p - Animation progress in [0, 1]; 1 renders the finished frame.
 */
function mergeStatesStep(p) {
    // Per-channel weighted blend of two "#rrggbb" colours:
    // weight1 * color1 + (1 - weight1) * color2, rounded to integers.
    function blendHexColors(color1, color2, weight1) {
        const c1 = parseInt(color1.slice(1), 16);
        const c2 = parseInt(color2.slice(1), 16);
        return '#' + [16, 8, 0]
            .map((shift) => {
                const mixed = weight1 * ((c1 >> shift) & 255) + (1 - weight1) * ((c2 >> shift) & 255);
                return Math.round(mixed).toString(16).padStart(2, '0');
            })
            .join('');
    }

    // Start from the finished "partial states" frame.
    computePartialStatesStep(1);

    const first = states.partition1;
    const second = states.partition2;
    const halfBox = config.stateBoxHeight / 2;
    const mergePointY = canvas.height / 2;

    if (!states.final) {
        // Log-sum-exp shifted by the larger score, so it stays finite even for large s.
        const sMax = Math.max(first.s, second.s);
        const s_final = sMax + Math.log(Math.exp(first.s - sMax) + Math.exp(second.s - sMax));
        // Weight of partition 1 in the merged state; partition 2 gets 1 - weight1.
        const weight1 = Math.exp(first.s - s_final);
        states.final = {
            x: config.finalStateX,
            y: mergePointY - halfBox,
            s: s_final,
            vColor: blendHexColors(first.vColor, second.vColor, weight1),
            label: `Final State (1..${config.kvCount})`
        };
    }

    // Phase 1: arrows from both partial states to the merge point, complete by p = 0.5.
    const arrowProgress = Math.min(p * 2, 1);
    const mergeArrowEndX = config.mergePointX - 10;
    drawArrow(first.x + config.stateBoxWidth, first.y + halfBox, mergeArrowEndX, mergePointY, config.colors.partition1, arrowProgress);
    drawArrow(second.x + config.stateBoxWidth, second.y + halfBox, mergeArrowEndX, mergePointY, config.colors.partition2, arrowProgress);

    if (p < 0.5) return;

    // Phase 2: fade in the ⊕ symbol, the output arrow and the final state.
    const symbolProgress = (p - 0.5) * 2;
    ctx.globalAlpha = symbolProgress;
    ctx.font = 'bold 30px Inter';
    ctx.fillStyle = config.colors.merged;
    ctx.textAlign = 'center';
    ctx.fillText('⊕', config.mergePointX, mergePointY + 10);

    drawArrow(config.mergePointX + 15, mergePointY, states.final.x, states.final.y + halfBox, config.colors.merged, symbolProgress);
    drawAttentionState(states.final, config.colors.merged);
    ctx.globalAlpha = 1.0;
}
|
|
|
|
|
/**
 * Replaces the message in the status line under the canvas.
 * Uses textContent (not innerHTML), so the message is always shown as plain text.
 *
 * @param {string} text - Message to show; '' clears the line.
 */
function setStatus(text) {
    Object.assign(statusText, { textContent: text });
}
|
|
|
/**
 * Syncs the buttons' disabled state with `currentStep`:
 * - Reset is always enabled;
 * - "Compute Partial States" is enabled only while 1 <= currentStep < 2;
 * - "Merge States" is enabled only while 2 <= currentStep < 3.
 */
function updateButtons() {
    const outside = (lo, hi) => currentStep < lo || currentStep >= hi;

    resetBtn.disabled = false;
    computeStatesBtn.disabled = outside(1, 2);
    mergeStatesBtn.disabled = outside(2, 3);
}
|
|
|
|
|
/**
 * Returns the per-channel arithmetic mean of two "#rrggbb" colours as a
 * lowercase "#rrggbb" string. Each channel is rounded with Math.round.
 *
 * @param {string} color1 - First colour, e.g. "#10b981".
 * @param {string} color2 - Second colour.
 * @returns {string} The averaged colour.
 */
function averageHexColors(color1, color2) {
    // Split a "#rrggbb" string into its [r, g, b] byte channels.
    const channelsOf = (hex) => {
        const packed = parseInt(hex.substring(1), 16);
        return [16, 8, 0].map((shift) => (packed >> shift) & 255);
    };
    const first = channelsOf(color1);
    const second = channelsOf(color2);

    let result = '#';
    first.forEach((value, i) => {
        const mean = Math.round((value + second[i]) / 2);
        result += mean.toString(16).padStart(2, '0');
    });
    return result;
}
|
|
|
|
|
|
|
// Reset is always available and rebuilds the scene from scratch.
resetBtn.addEventListener('click', () => {
    initialize();
});

// Each step button animates only when the walkthrough is at the step it belongs to.
// Row layout: [button, required step, status message, frame renderer, step reached afterwards].
[
    [computeStatesBtn, 1, 'Computing partial attention states...', computePartialStatesStep, 2],
    [mergeStatesBtn, 2, 'Merging attention states...', mergeStatesStep, 3],
].forEach(([button, requiredStep, message, renderFrame, nextStep]) => {
    button.addEventListener('click', () => {
        if (currentStep !== requiredStep) return;
        setStatus(message);
        animateStep(renderFrame, nextStep);
    });
});
|
|
|
|
|
/**
 * Draws the initial scene once the page has loaded.
 *
 * All canvas text uses the 'Inter' web font, and canvas drawing does not wait for
 * web fonts. The first draw is therefore deferred until document.fonts.ready
 * settles. Browsers without the Font Loading API draw immediately.
 */
window.addEventListener('load', () => {
    const fonts = document.fonts;
    if (fonts && fonts.ready) {
        fonts.ready.then(() => initialize(), () => initialize());
    } else {
        initialize();
    }
});
|
|
|
|
|
/**
 * Redraws the scene for the current step when the window is resized.
 *
 * The canvas bitmap has fixed width/height attributes and is not cleared by a
 * window resize, so this redraw is only a safety net. At step 1 it redraws the
 * query and KV pairs directly rather than calling initialize(). initialize() would
 * cancel a "Compute Partial States" animation that is still running and discard
 * the states it has created. If an animation is running, the frame drawn here is
 * simply replaced by its next frame.
 */
window.addEventListener('resize', () => {
    if (currentStep === 1) {
        ctx.clearRect(0, 0, canvas.width, canvas.height);
        drawQuery(query);
        kvPairs.forEach((kv, i) => drawKVPair(kv, i));
    } else if (currentStep === 2) {
        computePartialStatesStep(1);
    } else if (currentStep === 3) {
        mergeStatesStep(1);
    }
});
|
|
|
</script> |
|
</body> |
|
</html> |
|
|