support chunked prefill and fix minor bug

This commit is contained in:
GeekExplorer
2026-04-14 02:47:35 +08:00
parent 9e8507ef41
commit 8d63a98c03
8 changed files with 65 additions and 53 deletions
+3 -2
View File
@@ -23,7 +23,7 @@ class Qwen3Attention(nn.Module):
rms_norm_eps: float = 1e-06,
qkv_bias: bool = False,
rope_theta: float = 10000,
rope_scaling: tuple | None = None,
rope_scaling: dict | None = None,
) -> None:
super().__init__()
tp_size = dist.get_world_size()
@@ -51,12 +51,13 @@ class Qwen3Attention(nn.Module):
hidden_size,
bias=False,
)
if isinstance(rope_scaling, dict):
rope_theta = rope_scaling.get("rope_theta", rope_theta)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(
self.num_heads,