diff --git a/nanovllm/engine/llm_engine.py b/nanovllm/engine/llm_engine.py index ed4df26..2f97afe 100644 --- a/nanovllm/engine/llm_engine.py +++ b/nanovllm/engine/llm_engine.py @@ -18,6 +18,7 @@ class LLMEngine: config_fields = {field.name for field in fields(Config)} config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields} config = Config(model, **config_kwargs) + Sequence.set_block_size(config.kvcache_block_size) self.ps = [] self.events = [] ctx = mp.get_context("spawn") diff --git a/nanovllm/engine/sequence.py b/nanovllm/engine/sequence.py index e12f30d..3168888 100644 --- a/nanovllm/engine/sequence.py +++ b/nanovllm/engine/sequence.py @@ -12,9 +12,13 @@ class SequenceStatus(Enum): class Sequence: - block_size = 256 + block_size: int = 0 # invalid value, will be set by set_block_size counter = count() + @classmethod + def set_block_size(cls, block_size: int): + cls.block_size = block_size + def __init__(self, token_ids: list[int], sampling_params = SamplingParams()): self.seq_id = next(Sequence.counter) self.status = SequenceStatus.WAITING