analyze-generative-diffusion-model
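About
This skill analyzes pre-trained generative diffusion models (such as Stable Diffusion): it computes quality metrics (FID, CLIP score), visualizes attention maps, and probes the latent space. Use it to evaluate output quality, compare noise schedules, or analyze cross-attention patterns in text-conditioned generation. It is designed for developers performing advanced model evaluation and validation.
Quick Install
Claude Code
Recommended: npx skills add pjt222/agent-almanac -a claude-code
/plugin add https://github.com/pjt222/agent-almanac
git clone https://github.com/pjt222/agent-almanac.git ~/.claude/skills/analyze-generative-diffusion-model
Copy and paste the command into Claude Code to install the skill.
Documentation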
Analyze a Generative Diffusion Model
Evaluate pre-trained generative diffusion models. Quantitative quality metrics, noise schedule inspection, cross-attention map analysis, latent space probing. Understand model behavior, diagnose failure modes, guide fine-tuning decisions.
When to Use
- Evaluating pre-trained generative diffusion model's output quality with standard metrics
- Computing FID, IS, CLIP score, or precision/recall for generated image sets
- Inspecting and comparing noise schedules (linear, cosine, learned) via SNR curves
- Extracting cross-attention maps to understand text-to-image token-region correspondences
- Interpolating between latent codes or discovering semantic directions in latent space
- Detecting out-of-distribution inputs for diffusion model pipeline
Inputs
- Required: Pre-trained model identifier or checkpoint path (e.g., stabilityai/stable-diffusion-2-1)
- Required: Analysis mode: one or more of metrics, schedule, attention, latent
- Required: Reference dataset for metric computation (real images or dataset name)
- Optional: Text prompts for attention analysis (default: model-appropriate test prompts)
- Optional: Number of generated samples for metric computation (default: 10000)
- Optional: Device configuration (default: cuda if available, else cpu)
Steps
Step 1: Quantitative Evaluation
Compute standard generative quality metrics against reference dataset.
- Set up evaluation pipeline:
import torch
from diffusers import StableDiffusionPipeline
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on GPU; fall back to float32 on CPU.
dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=dtype
).to(device)

# normalize=True: both metrics expect float images in [0, 1].
fid = FrechetInceptionDistance(feature=2048, normalize=True).to(device)
inception = InceptionScore(normalize=True).to(device)
- Feed real images into metric accumulators:
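from torch.utils.data import DataLoader

# real_dataset: the reference dataset from Inputs, yielding float image tensors in [0, 1]
for batch in DataLoader(real_dataset, batch_size=64):
    # The metric was built with normalize=True, so pass floats in [0, 1], not uint8
    imgs = batch.float().to(device)
    fid.update(imgs, real=True)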
- Generate samples and accumulate fake statistics:
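from torchvision.transforms.functional import to_tensor

# load_evaluation_prompts is a user-supplied helper that returns a list of strings
prompts = load_evaluation_prompts("prompts.txt")  # one prompt per line
n_generated = 0
while n_generated < 10000:
    prompt_batch = prompts[n_generated:n_generated + 8]
    if not prompt_batch:
        break  # ran out of prompts before reaching the sample budget
    images = pipe(prompt_batch, num_inference_steps=50).images
    # to_tensor yields floats in [0, 1], which matches normalize=True
    tensors = torch.stack([to_tensor(img) for img in images]).to(device)
    fid.update(tensors, real=False)
    inception.update(tensors)
    n_generated += len(images)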
- Compute CLIP score for text-image alignment:
from torchmetrics.multimodal.clip_score import CLIPScore

clip_metric = CLIPScore(model_name_or_path="openai/clip-vit-large-patch14").to(device)
# sampled_prompts / sampled_tensors: a subset of the generated prompt-image pairs
for prompt, image_tensor in zip(sampled_prompts, sampled_tensors):
    clip_metric.update(image_tensor.unsqueeze(0), [prompt])

is_mean, is_std = inception.compute()  # compute once; returns (mean, std)
print(f"FID: {fid.compute().item():.2f}")
print(f"IS: {is_mean.item():.2f} +/- {is_std.item():.2f}")
print(f"CLIP: {clip_metric.compute().item():.2f}")
- Compute precision and recall for mode coverage:
# Precision: fraction of generated images near real manifold
# Recall: fraction of real images near generated manifold
# Use improved precision/recall (Kynkäänniemi et al., 2019) via
# feature embeddings from the Inception network
Got: FID below 30 for well-trained Stable Diffusion model on standard benchmarks. IS above 50 on ImageNet-class prompts. CLIP score above 25 for text-conditioned models. Precision and recall both above 0.6.
If fail: FID above 100? Verify real and generated images share same resolution and normalization. CLIP score low but FID acceptable? Model generates plausible images that do not match text prompt -- check text encoder. Ensure at least 10,000 samples for stable FID estimates.
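The precision/recall bullet in Step 1 only describes the idea in comments. As a rough illustration, the k-nearest-neighbour construction can be sketched in plain Python on toy 2-D points. The function names, the choice k=3, and the toy Gaussian "features" are all assumptions made for this sketch; a real evaluation would use Inception embeddings and a vectorized implementation.

```python
import math
import random

def kth_nn_radius(points, k):
    """Distance from each point to its k-th nearest neighbour in the same set."""
    radii = []
    for i, p in enumerate(points):
        dists = sorted(math.dist(p, q) for j, q in enumerate(points) if j != i)
        radii.append(dists[k - 1])
    return radii

def coverage(queries, support, radii):
    """Fraction of queries lying inside at least one support hypersphere."""
    hits = 0
    for q in queries:
        if any(math.dist(q, s) <= r for s, r in zip(support, radii)):
            hits += 1
    return hits / len(queries)

def precision_recall(real, fake, k=3):
    # Precision: generated points covered by the estimated real manifold.
    precision = coverage(fake, real, kth_nn_radius(real, k))
    # Recall: real points covered by the estimated generated manifold.
    recall = coverage(real, fake, kth_nn_radius(fake, k))
    return precision, recall

# Toy stand-ins for feature embeddings (hypothetical data, not model outputs).
rng = random.Random(0)
real = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(200)]
fake_good = [(rng.gauss(0, 1), rng.gauss(0, 1)) for _ in range(200)]
fake_collapsed = [(rng.gauss(0, 0.05), rng.gauss(0, 0.05)) for _ in range(200)]

p_good, r_good = precision_recall(real, fake_good)
p_col, r_col = precision_recall(real, fake_collapsed)
print(f"matched:   precision={p_good:.2f} recall={r_good:.2f}")
print(f"collapsed: precision={p_col:.2f} recall={r_col:.2f}")
```

In this toy setup the collapsed generator should keep high precision but lose most of its recall, which is the mode-coverage failure the metric is meant to expose.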
Step 2: Noise Schedule Inspection
Visualize and compare forward and reverse noise schedules.
- Extract schedule parameters from model:
scheduler = pipe.scheduler
# as_tensor accepts tensors or arrays and avoids the copy warning torch.tensor raises on tensors
betas = torch.as_tensor(scheduler.betas) if hasattr(scheduler, "betas") else None
alphas_cumprod = torch.as_tensor(scheduler.alphas_cumprod)
timesteps = torch.arange(len(alphas_cumprod))
- Compute signal-to-noise ratio curve:
import numpy as np
import matplotlib.pyplot as plt
snr = alphas_cumprod / (1 - alphas_cumprod)
log_snr = torch.log(snr)
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
axes[0].plot(timesteps.numpy(), alphas_cumprod.numpy())
axes[0].set_xlabel("Timestep"); axes[0].set_ylabel("alpha_cumprod")
axes[0].set_title("Cumulative Signal Retention")
axes[1].plot(timesteps.numpy(), log_snr.numpy())
axes[1].set_xlabel("Timestep"); axes[1].set_ylabel("log(SNR)")
axes[1].set_title("Log Signal-to-Noise Ratio")
if betas is not None:
    axes[2].plot(timesteps.numpy(), betas.numpy())
    axes[2].set_xlabel("Timestep"); axes[2].set_ylabel("beta")
    axes[2].set_title("Beta Schedule")
fig.tight_layout()
fig.savefig("noise_schedule.png", dpi=150)
- Compare multiple schedule types:
from diffusers import DDPMScheduler

schedules = {
    "linear": DDPMScheduler(beta_schedule="linear", num_train_timesteps=1000),
    "cosine": DDPMScheduler(beta_schedule="squaredcos_cap_v2", num_train_timesteps=1000),
}

fig, ax = plt.subplots(figsize=(10, 6))
for name, sched in schedules.items():
    ac = torch.as_tensor(sched.alphas_cumprod)
    snr = torch.log(ac / (1 - ac))
    ax.plot(snr.numpy(), label=name)
ax.set_xlabel("Timestep"); ax.set_ylabel("log(SNR)")
ax.set_title("Schedule Comparison"); ax.legend()
fig.savefig("schedule_comparison.png", dpi=150)
Got: Cosine schedule shows more gradual SNR decrease in mid-timesteps compared to linear. Log-SNR curve should span from approximately +10 (clean) to -10 (pure noise). Learned schedules should be monotonically decreasing.
If fail: alphas_cumprod not monotonically decreasing? Schedule misconfigured. Values constant? Check scheduler properly initialized with model's config. For custom schedulers, verify set_timesteps() called.
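To make the schedule comparison concrete without loading a model, the two schedules can be recomputed from their closed forms. This is a minimal sketch, assuming the common DDPM linear defaults (beta from 1e-4 to 0.02) and the cosine formulation with offset s=0.008 and a 0.999 beta cap; the function names are hypothetical and the values may differ from a given checkpoint's scheduler config.

```python
import math

def linear_alphas_cumprod(T=1000, beta_start=1e-4, beta_end=0.02):
    """Cumulative product of (1 - beta_t) for linearly spaced betas."""
    out, prod = [], 1.0
    for t in range(T):
        beta = beta_start + (beta_end - beta_start) * t / (T - 1)
        prod *= 1.0 - beta
        out.append(prod)
    return out

def cosine_alphas_cumprod(T=1000, s=0.008, max_beta=0.999):
    """Cumulative product for the squared-cosine schedule with capped betas."""
    def alpha_bar(u):
        return math.cos((u + s) / (1 + s) * math.pi / 2) ** 2

    out, prod = [], 1.0
    for t in range(T):
        beta = min(1.0 - alpha_bar((t + 1) / T) / alpha_bar(t / T), max_beta)
        prod *= 1.0 - beta
        out.append(prod)
    return out

def log_snr(alphas_cumprod):
    """log(SNR) with SNR = alpha_bar / (1 - alpha_bar)."""
    return [math.log(a / (1.0 - a)) for a in alphas_cumprod]

def is_strictly_decreasing(values):
    return all(b < a for a, b in zip(values, values[1:]))

lin = linear_alphas_cumprod()
cos = cosine_alphas_cumprod()
lin_snr, cos_snr = log_snr(lin), log_snr(cos)

print("linear monotone:", is_strictly_decreasing(lin))
print("cosine monotone:", is_strictly_decreasing(cos))
print(f"alpha_cumprod at t=500: linear={lin[500]:.3f} cosine={cos[500]:.3f}")
print(f"linear log-SNR range: {lin_snr[0]:.1f} to {lin_snr[-1]:.1f}")
```

The monotonicity helper mirrors the first failure check above: if it returns False for a scheduler's alphas_cumprod, the schedule is likely misconfigured.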
Step 3: Attention Map Analysis
Extract and visualize cross-attention maps from text-conditioned models.
- Register attention hooks on U-Net cross-attention layers:
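attention_maps = {}

def hook_fn(name):
    def fn(module, args, kwargs, output):
        # Cross-attention: Q from image, K/V from text.
        # The module output is the attended hidden states, not the attention
        # weights, so recompute softmax(QK^T) from the module's own projections.
        hidden = args[0] if args else kwargs["hidden_states"]
        context = kwargs.get("encoder_hidden_states")
        if context is None:
            return
        q = module.head_to_batch_dim(module.to_q(hidden))
        k = module.head_to_batch_dim(module.to_k(context))
        probs = module.get_attention_scores(q, k)  # (batch * heads, hw, seq_len)
        probs = probs.reshape(-1, module.heads, probs.shape[-2], probs.shape[-1])
        attention_maps[name] = probs.detach().float().cpu()
    return fn

for name, module in pipe.unet.named_modules():
    if 'attn2' in name and hasattr(module, 'processor'):
        module.register_forward_hook(hook_fn(name), with_kwargs=True)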
- Run inference and collect attention at specific timesteps:
prompt = "a red car parked next to a blue house"
timestep_attention = {}

# Custom callback to capture attention at specific timesteps
def callback_fn(pipe, step_index, timestep, callback_kwargs):
    if step_index in [5, 15, 30, 45]:
        timestep_attention[int(timestep)] = {
            k: v.clone() for k, v in attention_maps.items()
        }
    return callback_kwargs

output = pipe(prompt, num_inference_steps=50, callback_on_step_end=callback_fn)
- Visualize token-region correspondences:
tokenizer = pipe.tokenizer
tokens = tokenizer.encode(prompt)  # includes the start/end special tokens
token_strings = [tokenizer.decode([t]) for t in tokens]

# Select a mid-resolution attention layer (module names use "mid_block" / "up_blocks.1")
layer_key = [k for k in attention_maps if 'mid_block' in k or 'up_blocks.1' in k][0]
attn = attention_maps[layer_key]  # shape: (batch, heads, hw, seq_len)
attn_avg = attn.mean(dim=1)  # average across heads
res = int(attn_avg.shape[1] ** 0.5)
# With classifier-free guidance the batch is [unconditional, conditional]; use the last entry
attn_map = attn_avg[-1].reshape(res, res, -1)

n_tokens = min(len(token_strings), 6)
fig, axes = plt.subplots(2, n_tokens, figsize=(18, 6))
for idx, token in enumerate(token_strings[:n_tokens]):
    for row, (ts, ts_attn) in enumerate(list(timestep_attention.items())[:2]):
        a = ts_attn[layer_key].mean(dim=1)[-1]  # (hw, seq_len), conditional branch
        a_res = int(a.shape[0] ** 0.5)
        axes[row, idx].imshow(a[:, idx].reshape(a_res, a_res), cmap="hot")
        axes[row, idx].set_title(f"t={ts}: '{token}'")
        axes[row, idx].axis("off")
fig.suptitle("Cross-Attention Maps by Token and Timestep")
fig.tight_layout()
fig.savefig("attention_maps.png", dpi=150)
Got: Content tokens ("car", "house") activate localized spatial regions. Style/color tokens ("red", "blue") activate regions overlapping with their associated object. Early timesteps (high noise) show diffuse attention; later timesteps show sharp, localized attention.
If fail: All attention maps look uniform? Hook may be capturing self-attention instead of cross-attention -- verify layer name contains attn2 (cross) not attn1 (self). Attention captured but has wrong dimensions? Check output tensor indexing matches layer's head count and spatial resolution.
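The "diffuse versus localized" expectation can be checked numerically rather than by eye. One option, offered here as an assumption rather than something this skill prescribes, is the normalized spatial entropy of a single token's heatmap: values near 1 indicate near-uniform attention, values near 0 indicate attention concentrated on a few positions. The function name and the toy 8x8 maps are hypothetical.

```python
import math

def normalized_entropy(heatmap):
    """Spatial entropy of a non-negative 2-D map, scaled to the range [0, 1]."""
    flat = [v for row in heatmap for v in row]
    total = sum(flat)
    if total <= 0 or len(flat) < 2:
        return 0.0
    probs = [v / total for v in flat if v > 0]
    entropy = -sum(p * math.log(p) for p in probs)
    # Divide by the maximum possible entropy (uniform map over all positions).
    return entropy / math.log(len(flat))

# Toy maps standing in for one token's attention at two timesteps.
uniform = [[1.0] * 8 for _ in range(8)]   # fully diffuse
peaked = [[0.0] * 8 for _ in range(8)]    # fully localized
peaked[3][4] = 1.0

print(f"uniform map entropy: {normalized_entropy(uniform):.3f}")
print(f"peaked map entropy:  {normalized_entropy(peaked):.3f}")
```

Applied to a captured map, a per-token slice such as a[:, idx].reshape(a_res, a_res).tolist() could be passed in; tracking the value across the saved timesteps would then show whether attention sharpens as denoising proceeds.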
Step 4: Latent Space Probing
Explore structure of latent space through interpolation and direction discovery.
- Encode reference images into latent space:
from PIL import Image
import torchvision.transforms as T

vae = pipe.vae
transform = T.Compose([T.Resize(512), T.CenterCrop(512), T.ToTensor(),
                       T.Normalize([0.5], [0.5])])

def encode_image(image_path):
    img = transform(Image.open(image_path).convert("RGB")).unsqueeze(0)
    img = img.to(device=device, dtype=vae.dtype)  # match the VAE's precision
    with torch.no_grad():
        latent = vae.encode(img).latent_dist.sample() * vae.config.scaling_factor
    return latent

z1 = encode_image("image_a.png")
z2 = encode_image("image_b.png")
- Perform spherical linear interpolation (slerp):
def slerp(z1, z2, alpha):
    """Spherical linear interpolation between two latent codes."""
    # Compute the angle in float32 so half-precision latents do not lose accuracy
    z1_flat = z1.flatten().float()
    z2_flat = z2.flatten().float()
    omega = torch.acos(torch.clamp(
        torch.dot(z1_flat, z2_flat) / (z1_flat.norm() * z2_flat.norm()), -1, 1
    ))
    if omega.abs() < 1e-6:
        # Nearly parallel vectors: fall back to linear interpolation
        return (1 - alpha) * z1 + alpha * z2
    return (torch.sin((1 - alpha) * omega) * z1 + torch.sin(alpha * omega) * z2) / torch.sin(omega)

alphas = torch.linspace(0, 1, 8)
interpolated = [slerp(z1, z2, a.item()) for a in alphas]

decoded = []
for z in interpolated:
    with torch.no_grad():
        # Undo the scaling applied at encode time before decoding
        img = vae.decode(z / vae.config.scaling_factor).sample
    decoded.append(img.cpu())
- Discover semantic directions via prompt-pair differences:
def get_text_embedding(prompt):
    tokens = pipe.tokenizer(prompt, return_tensors="pt", padding="max_length",
                            max_length=77, truncation=True).input_ids.to(device)
    with torch.no_grad():
        emb = pipe.text_encoder(tokens).last_hidden_state
    return emb

pos_emb = get_text_embedding("a happy person smiling")
neg_emb = get_text_embedding("a sad person frowning")
direction = pos_emb - neg_emb  # semantic direction in text embedding space
- Detect out-of-distribution latents:
# Compute latent space statistics from a reference set
# reference_paths: list of image paths for the in-distribution reference set
ref_latents = torch.stack([encode_image(p) for p in reference_paths])
ref_mean = ref_latents.mean(dim=0)
ref_std = ref_latents.std(dim=0)

def ood_score(z):
    """Mahalanobis-like OOD score (higher = more unusual)."""
    deviation = ((z - ref_mean) / (ref_std + 1e-6)).flatten()
    return deviation.norm().item()

test_z = encode_image("test_image.png")
score = ood_score(test_z)
print(f"OOD score: {score:.2f} (reference mean: {np.mean([ood_score(r) for r in ref_latents]):.2f})")
Got: Interpolated images show smooth, semantically meaningful transitions without artifacts. Semantic directions produce consistent attribute changes when added to diverse latent codes. OOD scores for in-distribution images cluster tightly; outliers score significantly higher.
If fail: Interpolation produces blurry or incoherent midpoints? Use slerp instead of linear interpolation -- linear interpolation traverses low-density regions in high-dimensional latent spaces. Semantic directions have no visible effect? Increase direction magnitude or verify text encoder is same one used during model training.
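The slerp recommendation can be illustrated with vector norms alone. The sketch below draws two random Gaussian vectors as stand-ins for latents (an assumption; real latents come from the VAE) and compares the norm of the interpolation midpoint under linear interpolation and under slerp. The helper names and the dimension 4096 are illustrative choices.

```python
import math
import random

def norm(v):
    return math.sqrt(sum(x * x for x in v))

def lerp(a, b, t):
    return [(1 - t) * x + t * y for x, y in zip(a, b)]

def slerp(a, b, t):
    """Spherical interpolation on plain lists, mirroring the tensor version."""
    cos_omega = sum(x * y for x, y in zip(a, b)) / (norm(a) * norm(b))
    omega = math.acos(max(-1.0, min(1.0, cos_omega)))
    if abs(math.sin(omega)) < 1e-6:
        return lerp(a, b, t)  # degenerate angle: fall back to lerp
    wa = math.sin((1 - t) * omega) / math.sin(omega)
    wb = math.sin(t * omega) / math.sin(omega)
    return [wa * x + wb * y for x, y in zip(a, b)]

rng = random.Random(0)
dim = 4096
z1 = [rng.gauss(0, 1) for _ in range(dim)]
z2 = [rng.gauss(0, 1) for _ in range(dim)]

endpoint_norm = norm(z1)
lerp_mid_norm = norm(lerp(z1, z2, 0.5))
slerp_mid_norm = norm(slerp(z1, z2, 0.5))

print(f"endpoint norm:      {endpoint_norm:.1f}")
print(f"lerp midpoint norm: {lerp_mid_norm:.1f}")
print(f"slerp midpoint norm:{slerp_mid_norm:.1f}")
```

Because independent high-dimensional Gaussian vectors are close to orthogonal, the linear midpoint shrinks to roughly 70% of the endpoint norm, while the slerp midpoint stays near it. That shrinkage is one way to see why linear midpoints tend to decode poorly.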
Checks
- FID computed on at least 10,000 generated samples and matching real sample count
- CLIP score computed with same CLIP model used during training (if applicable)
- Noise schedule visualization shows monotonically decreasing alphas_cumprod
- Log-SNR spans approximately +10 to -10 across full timestep range
- Attention maps resolve per-token spatial activations at mid-resolution layers
- Attention sharpens from early (diffuse) to late (localized) timesteps
- Latent interpolations smooth with no sudden jumps or artifacts
- OOD detection baseline established from at least 100 reference samples
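The last check asks for an OOD baseline but leaves open how a score becomes a decision. One simple option, presented as an assumption and not as this skill's prescribed method, is to flag anything above a high percentile of the reference scores. The sketch uses plain lists of toy 16-dimensional vectors in place of flattened latents; fit_reference, percentile_threshold, and the 0.99 quantile are hypothetical choices.

```python
import math
import random

def fit_reference(latents):
    """Per-dimension mean and sample standard deviation of reference vectors."""
    n, dim = len(latents), len(latents[0])
    mean = [sum(z[d] for z in latents) / n for d in range(dim)]
    std = [
        math.sqrt(sum((z[d] - mean[d]) ** 2 for z in latents) / (n - 1))
        for d in range(dim)
    ]
    return mean, std

def ood_score(z, mean, std, eps=1e-6):
    """Norm of the per-dimension z-scores (diagonal Mahalanobis-like distance)."""
    return math.sqrt(sum(((x - m) / (s + eps)) ** 2 for x, m, s in zip(z, mean, std)))

def percentile_threshold(scores, q=0.99):
    """Score at quantile q of the reference set (nearest-rank style)."""
    ordered = sorted(scores)
    index = min(len(ordered) - 1, int(q * len(ordered)))
    return ordered[index]

# Toy data: 200 reference vectors plus one clearly shifted test vector.
rng = random.Random(0)
dim = 16
reference = [[rng.gauss(0, 1) for _ in range(dim)] for _ in range(200)]
mean, std = fit_reference(reference)
threshold = percentile_threshold([ood_score(z, mean, std) for z in reference])

shifted = [rng.gauss(5, 1) for _ in range(dim)]
shifted_score = ood_score(shifted, mean, std)
print(f"threshold={threshold:.2f} shifted score={shifted_score:.2f}")
print("flagged as OOD:", shifted_score > threshold)
```

With fewer reference samples the percentile estimate becomes noisy, which is consistent with the check's requirement of at least 100 reference samples.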
Pitfalls
- FID on mismatched resolutions: Real and generated images must be same resolution before feeding to Inception. Resize both sets identically or FID will be inflated.
- Forgetting to normalize for torchmetrics: FrechetInceptionDistance(normalize=True) expects [0, 1] float tensors. With normalize=False it expects [0, 255] uint8. Mixing conventions gives meaningless FID.
- Hooking self-attention instead of cross-attention: U-Net layers named attn1 are self-attention (image-to-image). Use attn2 for cross-attention (text-to-image). Confusing them produces uninformative uniform maps.
- Linear interpolation in high dimensions: Linear interpolation between two high-dimensional Gaussians passes through low-density shell. Always use slerp for latent space interpolation in diffusion models.
- Ignoring VAE scaling factor: Stable Diffusion latents scaled by vae.config.scaling_factor after encoding. Forgetting to apply or remove this factor produces garbled decoded images.
- Too few samples for precision/recall: Precision and recall estimates from fewer than 5,000 samples per set unreliable. Use at least 10,000 for stable estimates.
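The normalization pitfall is easy to see with plain arithmetic. The sketch below assumes a metric that rescales its input by 255 and stores the result in 8 bits, and shows what happens to a pixel that was already on the [0, 255] scale. The helper name is hypothetical and the modulo stands in for 8-bit wraparound; this is not a reproduction of any library's internals.

```python
def rescale_to_uint8(pixel):
    """Multiply by 255 and keep only the low 8 bits, as an 8-bit store would."""
    return int(pixel * 255) % 256

# A float pixel in [0, 1] maps to a sensible 8-bit value.
correct = rescale_to_uint8(0.5)

# A pixel already stored as 0..255 gets scaled a second time and wraps around.
double_scaled = rescale_to_uint8(200)

print(correct)        # 127
print(double_scaled)  # 56
```

A bright pixel (200) ending up as a dark one (56) is the kind of silent corruption that makes the resulting FID meaningless, so the input range should be checked once, right before the first update call.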
See Also
- implement-diffusion-network - building diffusion models that this skill evaluates
- analyze-diffusion-dynamics - mathematical foundations of noise processes inspected here
- fit-drift-diffusion-model - different diffusion model family sharing SDE foundations
