class GPUModelRunner(LoRAModelRunnerMixin):
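"""GPU model runner for a single worker.

Owns per-request state, GPU input buffers, block tables, attention metadata,
CUDA graphs, and (on the last PP rank) sampling/pooling, and executes the
model one scheduler step at a time. A typical step is:
    runner.execute_model(scheduler_output)
    runner.sample_tokens(grammar_output)  # or runner.pool() for pooling models
"""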
def __init__(self, vllm_config: VllmConfig, device: torch.device):
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.cache_config = vllm_config.cache_config
self.compilation_config = vllm_config.compilation_config
self.lora_config = vllm_config.lora_config
self.load_config = vllm_config.load_config
self.parallel_config = vllm_config.parallel_config
self.scheduler_config = vllm_config.scheduler_config
self.speculative_config = vllm_config.speculative_config
self.observability_config = vllm_config.observability_config
self.device = device
self.dtype = self.model_config.dtype
self.kv_cache_dtype = self.dtype
if self.cache_config.cache_dtype != "auto":
# Quantized KV cache.
self.kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
self.cache_config.cache_dtype
]
self.vocab_size = self.model_config.get_vocab_size()
self.max_model_len = self.model_config.max_model_len
self.max_num_tokens = self.scheduler_config.max_num_batched_tokens
self.max_num_reqs = self.scheduler_config.max_num_seqs
self.is_encoder_decoder = self.model_config.is_encoder_decoder
self.use_async_scheduling = self.scheduler_config.async_scheduling
self.output_copy_stream = torch.cuda.Stream(self.device)
# Pipeline parallelism.
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
self.is_first_pp_rank = get_pp_group().is_first_rank
self.is_last_pp_rank = get_pp_group().is_last_rank
# Persistent buffer for intermediate tensors (non-first PP ranks).
self.intermediate_tensors: IntermediateTensors | None = None
# Data parallelism.
self.dp_size = self.parallel_config.data_parallel_size
self.dp_rank = self.parallel_config.data_parallel_rank
# Decode context parallelism.
self.dcp_size = self.parallel_config.decode_context_parallel_size
self.use_dcp = self.dcp_size > 1
self.dcp_rank = get_dcp_group().rank_in_group if self.use_dcp else 0
self.cp_interleave = self.parallel_config.cp_kv_cache_interleave_size
# Multimodal.
self.mm_registry = MULTIMODAL_REGISTRY
self.supports_mm_inputs = self.mm_registry.supports_multimodal_inputs(
self.model_config
)
self.encoder_cache = None
if self.supports_mm_inputs and self.is_first_pp_rank:
self.encoder_cache = EncoderCache()
# Speculative decoding.
self.speculator = None
self.num_speculative_steps = 0
self.use_aux_hidden_state_outputs = False
if self.speculative_config is not None:
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
if self.is_last_pp_rank:
self.speculator = init_speculator(self.vllm_config, self.device)
if self.speculative_config.method == "eagle3":
# EAGLE3 may require auxiliary hidden states from target model outputs.
self.use_aux_hidden_state_outputs = True
if self.use_pp:
raise ValueError("EAGLE3 with pipeline parallel is not supported.")
# Draft token propagation for speculative decoding with structured outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device)
self.uniform_decode_query_len = 1 + self.num_speculative_steps
# Pooling models.
self.is_pooling_model = self.model_config.runner_type == "pooling"
self.pooling_runner: PoolingRunner | None = None
# General request states.
self.req_states = RequestState(
max_num_reqs=self.max_num_reqs,
max_model_len=self.max_model_len,
max_num_batched_tokens=self.max_num_tokens,
num_speculative_steps=self.num_speculative_steps,
vocab_size=self.vocab_size,
device=self.device,
)
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
max_num_tokens=self.max_num_tokens,
device=self.device,
)
self.sampler: Sampler | None = None
self.rejection_sampler: RejectionSampler | None = None
self.prompt_logprobs_worker: PromptLogprobsWorker | None = None
self.structured_outputs_worker: StructuredOutputsWorker | None = None
if self.is_last_pp_rank and not self.is_pooling_model:
# Initialize sampling-related workers.
# These components are only set up on the last PP rank and
# for generative (non-pooling) models.
self.sampler = Sampler(
max_num_reqs=self.max_num_reqs,
vocab_size=self.vocab_size,
device=self.device,
req_states=self.req_states,
logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1,
)
if self.speculative_config is not None:
self.rejection_sampler = RejectionSampler(
self.sampler,
self.speculative_config,
self.device,
)
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
self.structured_outputs_worker = StructuredOutputsWorker(
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
vocab_size=self.vocab_size,
device=self.device,
)
# For CUDA graphs; cudagraph_manager is initialized after init_attn_backend.
self.decode_query_len = self.num_speculative_steps + 1
self.cudagraph_manager: ModelCudaGraphManager | None = None
# LoRA-related workers.
self.lora_state = LoraState(max_num_reqs=self.max_num_reqs)
# KV Connector if configured.
self.kv_connector: KVConnector = NO_OP_KV_CONNECTOR
# For transferring state from execute_model to subsequent sample_tokens call.
self.execute_model_state: ExecuteModelState | None = None
# Expert parallelism load balancer.
self.eplb = EPLBController(self.parallel_config, self.device)
def update_max_model_len(self, max_model_len: int) -> None:
self.max_model_len = max_model_len
self.req_states.max_model_len = max_model_len
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
tasks: list[SupportedTask] = []
if self.model_config.runner_type == "generate":
tasks.extend(self.model_state.get_supported_generation_tasks())
if self.is_pooling_model:
# Do not rely on pooling_runner here, since this information is needed
# on the first PP rank, while pooling_runner is only initialized
# on the last PP rank.
tasks.extend(PoolingRunner.get_supported_tasks(self.model))
return tuple(tasks)
def load_model(self, load_dummy_weights: bool = False, *args, **kwargs) -> None:
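"""Load the model weights (or dummy weights for profiling), then set up the
pieces that depend on the loaded model: LoRA, the speculator, EPLB
registration, model_state, the pooling runner, and, on non-first PP ranks,
the persistent intermediate-tensor buffers."""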
time_before_load = time.perf_counter()
if load_dummy_weights:
self.load_config.load_format = "dummy"
self.eplb.prepare_load()
eplb_models_added = False
with DeviceMemoryProfiler() as m:
model_loader = get_model_loader(self.vllm_config.load_config)
logger.info("Loading model from scratch...")
self.model = model_loader.load_model(
vllm_config=self.vllm_config, model_config=self.vllm_config.model_config
)
if self.lora_config:
self.model = self.load_lora_model(
self.model, self.vllm_config, self.device
)
if self.use_aux_hidden_state_outputs:
assert self.speculative_config is not None
set_eagle3_aux_hidden_state_layers(self.model, self.speculative_config)
if self.speculator is not None:
self.speculator.load_model(self.model)
eplb_models_added = self.eplb.maybe_register_speculator(
self.speculator, self.speculative_config, load_dummy_weights
)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info(
"Model loading took %s GiB and %.6f seconds",
format_gib(m.consumed_memory),
time_after_load - time_before_load,
)
if not load_dummy_weights:
prepare_communication_buffer_for_model(self.model)
if self.speculator is not None:
prepare_communication_buffer_for_model(self.speculator.model)
# Initialize the components that require the model.
self.model_state = init_model_state(
self.vllm_config, self.model, self.encoder_cache, self.device
)
if self.is_pooling_model and self.is_last_pp_rank:
self.pooling_runner = PoolingRunner(self.model)
eplb_models_added |= self.eplb.maybe_register_model(
self.model,
self.model_config,
load_dummy_weights,
)
self.eplb.maybe_start_async_loop(eplb_models_added)
if not self.is_first_pp_rank:
# For non-first PP ranks, create intermediate tensors sized
# for the max capture size so they can be sliced per batch.
# Save as persistent member so runtime can copy received data
# into the same addresses that the CUDA graphs captured.
self.intermediate_tensors = self.model.make_empty_intermediate_tensors(
batch_size=self.max_num_tokens,
dtype=self.model_config.dtype,
device=self.device,
)
def get_model(self) -> nn.Module:
return self.model
@functools.cached_property
def main_stream(self) -> torch.cuda.Stream:
# Cache the current CUDA stream to avoid repeated lookups.
return torch.cuda.current_stream(self.device)
def get_kv_cache_spec(self):
return get_kv_cache_spec(self.vllm_config)
def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
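"""Set up everything that depends on the KV cache layout: block tables,
attention (and Mamba) backends, the resolved CUDA graph mode and manager,
the KV cache tensors themselves, and the KV connector."""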
kv_cache_config = deepcopy(kv_cache_config)
self.kv_cache_config = kv_cache_config
block_sizes = [
kv_cache_group.kv_cache_spec.block_size
for kv_cache_group in kv_cache_config.kv_cache_groups
]
block_table_max_model_len = self.max_model_len
if self.is_encoder_decoder:
# Cross-attention block tables need to index encoder tokens
# (e.g., Whisper ~1500), which can exceed decoder max_model_len.
block_table_max_model_len = max(
block_table_max_model_len,
getattr(self.model_config.hf_config, "max_source_positions", 0),
)
self.block_tables = BlockTables(
block_sizes=block_sizes,
max_num_reqs=self.max_num_reqs,
max_num_batched_tokens=self.max_num_tokens,
max_model_len=block_table_max_model_len,
device=self.device,
cp_size=self.dcp_size,
cp_rank=self.dcp_rank,
cp_interleave=self.cp_interleave,
)
self.attn_backends, self.attn_groups, attn_cg_support = init_attn_backend(
self.kv_cache_config, self.vllm_config, self.device
)
initialize_mamba_ssu_backend(
self.vllm_config.mamba_config, self.kv_cache_config
)
cudagraph_mode = self.compilation_config.resolve_cudagraph_mode_and_sizes(
attn_cg_support.min_cg_support,
attn_cg_support.min_cg_attn_backend,
self.uniform_decode_query_len,
self.parallel_config.tensor_parallel_size,
self.kv_cache_config,
self.max_num_reqs,
)
self.cudagraph_manager = ModelCudaGraphManager(
self.vllm_config,
self.device,
cudagraph_mode,
decode_query_len=self.decode_query_len,
)
if self.speculator is not None:
self.speculator.init_cudagraph_manager(cudagraph_mode)
check_attention_cp_compatibility(self.vllm_config)
if self.speculator is not None:
# HACK(woosuk): wire the speculator directly into the target model's
# attention state (model_state, KV cache config, block tables).
self.speculator.set_attn(
self.model_state,
self.kv_cache_config,
self.block_tables,
)
self.kv_caches: list[torch.Tensor] = []
kv_caches_dict = init_kv_cache(
self.kv_caches,
self.compilation_config.static_forward_context,
self.kv_cache_config,
self.attn_backends,
self.device,
self.cache_config.cache_dtype,
)
self.kv_connector = get_kv_connector(self.vllm_config, kv_caches_dict)
@torch.inference_mode()
@step_eplb_after(is_dummy=True)
def _dummy_run(
self,
num_tokens: int,
*args,
skip_attn: bool = False,
uniform_decode: bool = False,
skip_eplb: bool = False,
is_profile: bool = False,
**kwargs,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
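"""Execute the model on a synthetic batch of `num_tokens` tokens.

Used for memory profiling (is_profile=True), CUDA graph warmup/capture, and
keeping DP/EP ranks in sync when a rank has no real work. Returns
(hidden_states, sample_hidden_states) on the last PP rank, (None, None)
otherwise."""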
if skip_attn and not is_profile:
raise ValueError(
"skip_attn must only be True for initial memory profiling."
)
# Create a dummy scheduler output.
num_reqs = min(num_tokens, self.max_num_reqs)
if uniform_decode:
# HACK(lucas): the worker is currently shared between MRV1 and MRV2, and for
# spec-decode with MTP the dummy runs must use 1 + num_speculative_tokens
# tokens per request, so we take the max here. This will likely change in
# the worker eventually: https://github.com/vllm-project/vllm/pull/35243
num_tokens = max(num_tokens, self.decode_query_len)
num_reqs = num_tokens // self.decode_query_len
assert num_tokens % self.decode_query_len == 0
num_tokens_per_request = [num_tokens // num_reqs] * num_reqs
num_tokens_per_request[-1] += num_tokens % num_reqs
assert sum(num_tokens_per_request) == num_tokens
num_scheduled_tokens = {
f"_dummy_req_{i}": n for i, n in enumerate(num_tokens_per_request)
}
dummy_scheduler_output = SchedulerOutput.make_empty()
dummy_scheduler_output.total_num_scheduled_tokens = num_tokens
dummy_scheduler_output.num_scheduled_tokens = num_scheduled_tokens
# Disable any use of KVConnector for dummy runs.
self.kv_connector.set_disabled(True)
# Get the intermediate tensors for the dummy run.
intermediate_tensors = None
if not self.is_first_pp_rank:
assert self.intermediate_tensors is not None
intermediate_tensors = self.intermediate_tensors[:num_tokens]
# Execute the model.
self.execute_model(
dummy_scheduler_output,
intermediate_tensors=intermediate_tensors,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
is_profile=is_profile,
)
self.kv_connector.set_disabled(False)
# Non-last PP ranks don't produce output for sampling.
if not self.is_last_pp_rank:
return None, None
assert self.execute_model_state is not None
input_batch = self.execute_model_state.input_batch
attn_metadata = self.execute_model_state.attn_metadata
slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
hidden_states = self.execute_model_state.hidden_states
aux_hidden_states = self.execute_model_state.aux_hidden_states
self.execute_model_state = None
# Dummy-run the EAGLE speculator's propose() to keep DP/EP ranks in sync.
if self.speculator is not None:
assert self.sampler is not None
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
if self.speculator.supports_mm_inputs:
mm_inputs = (
[],
torch.zeros(
input_batch.num_tokens,
dtype=torch.bool,
device=self.device,
),
)
# Let the target override the hidden state fed to the drafter
# (e.g. DeepSeek V4 MTP needs the pre-hc_head residual). The
# target returns a persistent buffer sized at max_num_batched_tokens;
# slice to the active token count that propose() expects.
spec_hidden_states = hidden_states
if hasattr(self.model, "get_mtp_target_hidden_states"):
pre_hc_hidden_states = self.model.get_mtp_target_hidden_states()
spec_hidden_states = pre_hc_hidden_states[: hidden_states.shape[0]] # type: ignore[union-attr]
self.speculator.propose(
input_batch=input_batch,
attn_metadata=attn_metadata,
slot_mappings=slot_mappings_by_layer,
last_hidden_states=spec_hidden_states,
aux_hidden_states=aux_hidden_states,
num_sampled=torch.ones(
input_batch.num_reqs, dtype=torch.int32, device=self.device
),
num_rejected=torch.zeros(
input_batch.num_reqs, dtype=torch.int32, device=self.device
),
last_sampled=self.req_states.last_sampled_tokens,
next_prefill_tokens=self.req_states.next_prefill_tokens,
temperature=self.sampler.sampling_states.temperature.gpu,
seeds=self.sampler.sampling_states.seeds.gpu,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
mm_inputs=mm_inputs,
is_profile=is_profile,
)
assert hidden_states is not None # Last PP rank always has hidden_states
sample_hidden_states = hidden_states[input_batch.logits_indices]
return hidden_states, sample_hidden_states
@torch.inference_mode()
def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
num_reqs = hidden_states.shape[0]
logits = self.model.compute_logits(hidden_states)
dummy_input_batch = InputBatch.make_dummy(
num_reqs, num_reqs, self.input_buffers
)
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
assert self.sampler is not None
self.sampler(logits, dummy_input_batch)
@torch.inference_mode()
def _dummy_pooler_run(self, hidden_states: torch.Tensor) -> None:
assert self.pooling_runner is not None
self.pooling_runner.dummy_pooler_run(hidden_states)
@torch.inference_mode()
def profile_run(self) -> None:
hidden_states, sample_hidden_states = self._dummy_run(
self.max_num_tokens, skip_attn=True, is_profile=True
)
# Only run sampler/pooler on last PP rank (non-last ranks return None).
if self.is_last_pp_rank:
assert sample_hidden_states is not None
if self.pooling_runner is None:
self._dummy_sampler_run(sample_hidden_states)
else:
self._dummy_pooler_run(hidden_states)
torch.accelerator.synchronize()
del hidden_states, sample_hidden_states
gc.collect()
def post_kv_cache_wake_up(self) -> None:
self.block_tables.init_block_table_layout_tensors()
def reset_mm_cache(self) -> None:
if self.encoder_cache is not None:
self.encoder_cache.reset_mm_cache()
def reset_encoder_cache(self) -> None:
if self.encoder_cache is not None:
self.encoder_cache.reset_encoder_cache()
def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int:
# Sequence parallelism (SP) is not supported yet.
return num_scheduled_tokens
def profile_cudagraph_memory(self) -> int:
# NOTE(woosuk): It is TBD whether we keep this API or not.
return 0
@torch.inference_mode()
def capture_model(self) -> int:
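"""Capture CUDA graphs for the model (and the speculator, if any) and return
the GPU memory consumed by the captured graphs, in bytes."""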
assert self.cudagraph_manager is not None
if not self.cudagraph_manager.needs_capture():
logger.warning(
"Skipping CUDA graph capture. To turn on CUDA graph capture, "
"ensure `cudagraph_mode` was not manually set to `NONE`"
)
return 0
compilation_counter.num_gpu_runner_capture_triggers += 1
start_time = time.perf_counter()
gc.collect()
torch.accelerator.empty_cache()
start_free_gpu_memory = torch.cuda.mem_get_info()[0]
with self.maybe_setup_dummy_loras(self.lora_config):
captured_attn_states = self.cudagraph_manager.capture(
self.model,
self.model_state,
self.input_buffers,
self.intermediate_tensors,
self.block_tables,
self.attn_groups,
self.kv_cache_config,
has_lora=self.lora_config is not None,
use_aux_hidden_state_outputs=self.use_aux_hidden_state_outputs,
)
if self.speculator is not None:
self.speculator.capture(captured_attn_states)
end_time = time.perf_counter()
end_free_gpu_memory = torch.cuda.mem_get_info()[0]
elapsed_time = end_time - start_time
cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory
# This usually takes 5~20 seconds.
logger.info(
"Graph capturing finished in %.0f secs, took %.2f GiB",
elapsed_time,
cuda_graph_size / (1 << 30),
)
return cuda_graph_size
def _remove_request(self, req_id: str) -> bool:
if not self.req_states.remove_request(req_id):
return False
if self.encoder_cache is not None:
self.encoder_cache.remove_request(req_id)
if self.prompt_logprobs_worker is not None:
self.prompt_logprobs_worker.remove_request(req_id)
self.lora_state.remove_request(req_id)
return True
def finish_requests(self, scheduler_output: SchedulerOutput) -> None:
finished_req_ids = scheduler_output.finished_req_ids
preempted_req_ids = scheduler_output.preempted_req_ids
if preempted_req_ids:
finished_req_ids = finished_req_ids.union(preempted_req_ids)
for req_id in finished_req_ids:
self._remove_request(req_id)
def free_states(self, scheduler_output: SchedulerOutput) -> None:
if self.encoder_cache is not None:
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_cache.free_encoder_cache(mm_hash)
def add_requests(self, scheduler_output: SchedulerOutput) -> None:
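"""Register newly scheduled requests: stage per-request state, block tables,
encoder cache entries, LoRA state, and (on the last PP rank) sampling and
prompt-logprobs state, then apply all staged writes at once."""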
for new_req_data in scheduler_output.scheduled_new_reqs:
assert new_req_data.prompt_token_ids is not None
assert new_req_data.prefill_token_ids is not None
req_id = new_req_data.req_id
# Streaming input update: if the request already exists from a prior
# chunk, remove its old state so it can be cleanly re-added below with
# the updated prompt_token_ids and mm_features. This is a no-op for
# brand-new requests.
self._remove_request(req_id)
prompt_len = len(new_req_data.prompt_token_ids)
self.req_states.add_request(
req_id=req_id,
prompt_len=prompt_len,
all_token_ids=new_req_data.prefill_token_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
)
req_index = self.req_states.req_id_to_index[req_id]
if self.encoder_cache is not None:
self.encoder_cache.add_request(req_id, new_req_data.mm_features)
self.model_state.add_request(req_index, new_req_data)
self.block_tables.append_block_ids(
req_index, new_req_data.block_ids, overwrite=True
)
self.lora_state.add_request(req_id, req_index, new_req_data.lora_request)
if self.is_last_pp_rank and new_req_data.sampling_params is not None:
assert self.sampler is not None
self.sampler.add_request(
req_index, prompt_len, new_req_data.sampling_params
)
assert self.prompt_logprobs_worker is not None
self.prompt_logprobs_worker.add_request(
req_id, req_index, new_req_data.sampling_params
)
if scheduler_output.scheduled_new_reqs:
self.req_states.apply_staged_writes()
self.model_state.apply_staged_writes()
if self.sampler is not None:
self.sampler.apply_staged_writes()
def update_requests(self, scheduler_output: SchedulerOutput) -> None:
# Add new blocks for the existing requests.
reqs = scheduler_output.scheduled_cached_reqs
for req_new_block_ids, req_id in zip(reqs.new_block_ids, reqs.req_ids):
if req_new_block_ids is not None:
req_index = self.req_states.req_id_to_index[req_id]
self.block_tables.append_block_ids(
req_index, req_new_block_ids, overwrite=False
)
def prepare_inputs(
self, scheduler_output: SchedulerOutput, batch_desc: BatchExecutionDescriptor
) -> InputBatch:
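"""Build the InputBatch for this step.

Orders requests decode-first, maps them to persistent request slots, expands
the index mapping for draft tokens, computes query_start_loc / positions /
seq_lens (padded for CUDA graphs), and stages input token ids into the
persistent GPU input buffers."""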
num_tokens = scheduler_output.total_num_scheduled_tokens
num_tokens_after_padding = batch_desc.num_tokens
assert num_tokens > 0
num_tokens_per_req = scheduler_output.num_scheduled_tokens
num_reqs = len(num_tokens_per_req)
# Decode first, then prefill.
# batch_idx -> req_id
req_ids = sorted(num_tokens_per_req, key=num_tokens_per_req.get) # type: ignore[arg-type]
numtoks_iter = map(num_tokens_per_req.get, req_ids)
num_scheduled_tokens = np.fromiter(numtoks_iter, dtype=np.int32, count=num_reqs)
idx_mapping_iter = map(self.req_states.req_id_to_index.get, req_ids)
idx_mapping_np = np.fromiter(idx_mapping_iter, dtype=np.int32, count=num_reqs)
idx_mapping = async_copy_to_gpu(idx_mapping_np, device=self.device)
# Get the number of draft tokens for each request.
draft_tokens = scheduler_output.scheduled_spec_decode_tokens
if not draft_tokens:
# No draft token scheduled (common case).
total_num_draft_tokens = 0
total_num_logits = num_reqs
cu_num_logits_np = np.arange(num_reqs + 1, dtype=np.int32)
cu_num_logits = torch.arange(
num_reqs + 1, device=self.device, dtype=torch.int32
)
expanded_idx_mapping = idx_mapping
expanded_local_pos = torch.zeros(
num_reqs, dtype=torch.int32, device=self.device
)
else:
num_draft_tokens = np.fromiter(
(len(draft_tokens.get(req_id, ())) for req_id in req_ids),
dtype=np.int32,
count=num_reqs,
)
total_num_draft_tokens = int(num_draft_tokens.sum())
total_num_logits = num_reqs + total_num_draft_tokens
num_logits = num_draft_tokens + 1
cu_num_logits_np = np.empty(num_reqs + 1, dtype=np.int32)
cu_num_logits_np[0] = 0
np.cumsum(num_logits, out=cu_num_logits_np[1:])
cu_num_logits = async_copy_to_gpu(cu_num_logits_np, device=self.device)
max_expand_len = self.num_speculative_steps + 1
expanded_idx_mapping, expanded_local_pos = expand_idx_mapping(
idx_mapping, total_num_logits, cu_num_logits, max_expand_len
)
# Get query_start_loc.
# batch_desc.num_reqs is None for PIECEWISE graphs (no request padding needed).
num_reqs_padded = batch_desc.num_reqs or num_reqs
query_start_loc_np = np.empty(self.max_num_reqs + 1, dtype=np.int32)
query_start_loc_np[0] = 0
np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1 : num_reqs + 1])
# Pad for full CUDA graph mode.
# Some attention backends like FA3 require query_start_loc to be non-decreasing.
query_start_loc_np[num_reqs + 1 :] = num_tokens
async_copy_to_gpu(query_start_loc_np, out=self.input_buffers.query_start_loc)
query_start_loc_np = query_start_loc_np[: num_reqs_padded + 1]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs_padded + 1]
# Get prefill tokens if any.
if self.req_states.any_prefills(idx_mapping_np):
prepare_prefill_inputs(
self.input_buffers.input_ids,
self.req_states.next_prefill_tokens,
idx_mapping,
query_start_loc,
self.req_states.all_token_ids.gpu,
self.req_states.prefill_len.gpu,
self.req_states.num_computed_tokens.gpu,
)
# Prepare positions and seq_lens.
prepare_pos_seq_lens(
idx_mapping,
query_start_loc,
self.req_states.num_computed_tokens.gpu,
self.input_buffers.positions,
self.input_buffers.seq_lens,
)
seq_lens = self.input_buffers.seq_lens[:num_reqs_padded]
dcp_local_seq_lens = None
if self.use_dcp:
# Prepare dcp local seq_lens.
prepare_dcp_local_seq_lens(
self.input_buffers.dcp_local_seq_lens,
self.input_buffers.seq_lens,
num_reqs,
self.dcp_size,
self.dcp_rank,
self.cp_interleave,
)
dcp_local_seq_lens = self.input_buffers.dcp_local_seq_lens[:num_reqs_padded]
# Some input token ids are directly read from the last sampled tokens
# and draft tokens. Also, get the logits indices to sample tokens from.
logits_indices = combine_sampled_and_draft_tokens(
self.input_buffers.input_ids,
idx_mapping,
self.req_states.last_sampled_tokens,
query_start_loc,
seq_lens,
self.req_states.prefill_len.gpu,
self.req_states.draft_tokens,
cu_num_logits,
total_num_logits,
)
# CPU upper bound on seq_lens; padded entries left at zero.
seq_lens_cpu_upper_bound_np = np.zeros(num_reqs_padded, dtype=np.int32)
np.add(
self.req_states.num_computed_tokens_np[idx_mapping_np],
num_scheduled_tokens,
out=seq_lens_cpu_upper_bound_np[:num_reqs],
)
seq_lens_cpu_upper_bound = torch.from_numpy(seq_lens_cpu_upper_bound_np)
return InputBatch(
req_ids=req_ids,
num_reqs=num_reqs,
num_reqs_after_padding=num_reqs_padded,
idx_mapping=idx_mapping,
idx_mapping_np=idx_mapping_np,
expanded_idx_mapping=expanded_idx_mapping,
expanded_local_pos=expanded_local_pos,
num_scheduled_tokens=num_scheduled_tokens,
num_tokens=num_tokens,
num_tokens_after_padding=num_tokens_after_padding,
num_draft_tokens=total_num_draft_tokens,
query_start_loc=query_start_loc,
query_start_loc_np=query_start_loc_np,
seq_lens=seq_lens,
seq_lens_cpu_upper_bound=seq_lens_cpu_upper_bound,
dcp_local_seq_lens=dcp_local_seq_lens,
input_ids=self.input_buffers.input_ids[:num_tokens_after_padding],
positions=self.input_buffers.positions[:num_tokens_after_padding],
logits_indices=logits_indices,
cu_num_logits=cu_num_logits,
cu_num_logits_np=cu_num_logits_np,
has_structured_output_reqs=scheduler_output.has_structured_output_requests,
)
def prepare_attn(
self, input_batch: InputBatch
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
# Block tables: num_kv_cache_groups x [num_reqs_padded, max_num_blocks].
block_tables = self.block_tables.gather_block_tables(
input_batch.idx_mapping,
num_reqs_padded=input_batch.num_reqs_after_padding,
)
# Slot mappings: [num_kv_cache_groups, num_tokens_padded].
# Kernel pads beyond num_tokens with PAD_SLOT_ID.
slot_mappings = self.block_tables.compute_slot_mappings(
input_batch.idx_mapping,
input_batch.query_start_loc,
input_batch.positions,
num_tokens_padded=input_batch.num_tokens_after_padding,
)
return block_tables, slot_mappings
def prepare_dummy_attn(
self, input_batch: InputBatch
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor]:
block_tables = self.block_tables.get_dummy_block_tables(input_batch.num_reqs)
slot_mappings = self.block_tables.get_dummy_slot_mappings(
input_batch.num_tokens
)
return block_tables, slot_mappings
def sample(
self,
hidden_states: torch.Tensor,
input_batch: InputBatch,
grammar_output: GrammarOutput | None,
) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
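"""Compute logits at the sampling positions, apply the grammar bitmask when
structured outputs are requested, run the sampler (or the rejection sampler
when draft tokens are present), and return the sampler output together with
the per-request counts of sampled and rejected tokens."""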
sample_hidden_states = hidden_states[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None:
# Apply grammar bitmask to the logits in-place.
assert self.structured_outputs_worker is not None
self.structured_outputs_worker.apply_grammar_bitmask(
logits,
input_batch,
grammar_output.structured_output_request_ids,
grammar_output.grammar_bitmask,
)
if input_batch.num_draft_tokens == 0:
# No draft tokens (common case).
assert self.sampler is not None
sampler_output = self.sampler(logits, input_batch)
else:
# Rejection sampling for spec decoding.
assert self.rejection_sampler is not None
assert self.speculator is not None
sampler_output = self.rejection_sampler(
logits,
input_batch,
# Draft logits are needed for probabilistic rejection sampling.
self.speculator.draft_logits,
)
# Get the number of sampled and rejected tokens.
# For chunked prefills, num_sampled and num_rejected are both 0.
num_sampled, num_rejected = get_num_sampled_and_rejected(
sampler_output.num_sampled,
input_batch.seq_lens,
input_batch.cu_num_logits,
input_batch.idx_mapping,
self.req_states.prefill_len.gpu,
)
return sampler_output, num_sampled, num_rejected
def postprocess(
self,
input_batch: InputBatch,
sampled_tokens: torch.Tensor,
num_sampled: torch.Tensor,
num_rejected: torch.Tensor,
) -> None:
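"""Fold the sampling results back into the per-request state: advance
num_computed_tokens, update the last sampled tokens and output bin counts on
GPU, and optimistically advance the CPU-side prefill/computed-token
counters."""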
# Update the number of computed tokens.
if self.is_last_pp_rank:
assert self.sampler is not None
output_bin_counts = self.sampler.penalties_state.output_bin_counts
else:
output_bin_counts = None
post_update(
input_batch.idx_mapping,
self.req_states.num_computed_tokens.gpu,
self.req_states.last_sampled_tokens,
output_bin_counts,
sampled_tokens,
num_sampled,
num_rejected,
input_batch.query_start_loc,
self.req_states.all_token_ids.gpu,
self.req_states.total_len.gpu,
)
# Update the number of computed prefill tokens.
idx_mapping_np = input_batch.idx_mapping_np
computed_prefill = self.req_states.num_computed_prefill_tokens
computed_prefill[idx_mapping_np] += input_batch.num_scheduled_tokens
np.minimum(
computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
)
# Advance the CPU mirror optimistically (assume all scheduled accepted).
self.req_states.num_computed_tokens_np[idx_mapping_np] += (
input_batch.num_scheduled_tokens
)
@torch.inference_mode()
def execute_model(
self,
scheduler_output: SchedulerOutput,
intermediate_tensors: IntermediateTensors | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
is_profile: bool = False,
) -> ModelRunnerOutput | IntermediateTensors | None:
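"""Run the model for one scheduler step.

Updates request states, dispatches the CUDA graph / eager execution mode
(synchronized across DP ranks), prepares inputs and attention metadata, runs
the forward pass, and stashes the results in execute_model_state for the
subsequent sample_tokens() or pool() call. Non-last PP ranks return
IntermediateTensors to be sent downstream; the last PP rank returns None."""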
if not dummy_run:
# Update the request states.
self.finish_requests(scheduler_output)
self.free_states(scheduler_output)
self.add_requests(scheduler_output)
self.update_requests(scheduler_output)
self.block_tables.apply_staged_writes()
if scheduler_output.total_num_scheduled_tokens == 0:
# No need to run the model.
empty_output = self.kv_connector.no_forward(scheduler_output)
return empty_output
# Get batch descriptor and sync across DP ranks.
num_reqs = len(scheduler_output.num_scheduled_tokens)
num_toks = scheduler_output.total_num_scheduled_tokens
max_query_len = max(scheduler_output.num_scheduled_tokens.values())
uniform_tok_count = get_uniform_token_count(num_reqs, num_toks, max_query_len)
skip_compiled = False
if self.is_encoder_decoder and scheduler_output.scheduled_encoder_inputs:
# Encoder-decoder models such as Whisper should run eager/non-compiled
# when encoder inputs are scheduled, because this step updates
# cross-attention cache with dynamic encoder outputs.
skip_compiled = True
batch_desc, num_tokens_across_dp = dispatch_cg_and_sync_dp(
self.cudagraph_manager,
num_reqs,
num_toks,
uniform_tok_count,
self.dp_size,
self.dp_rank,
need_eager=is_profile or skip_compiled,
)
if batch_desc.num_tokens == 0:
# All DP ranks have zero tokens to run.
empty_output = self.kv_connector.no_forward(scheduler_output)
return empty_output
if not dummy_run:
# Common case.
# Prepare all the inputs and copy to the input buffers.
input_batch = self.prepare_inputs(scheduler_output, batch_desc)
block_tables, slot_mappings = self.prepare_attn(input_batch)
if self.lora_config:
# Activate LoRA adapters.
lora_inputs = self.lora_state.make_lora_inputs(
input_batch.req_ids,
input_batch.idx_mapping_np,
input_batch.num_scheduled_tokens,
)
self._set_active_loras(*lora_inputs)
else:
# No actual tokens to run. A dummy run for DP or memory profiling.
input_batch = InputBatch.make_dummy(
batch_desc.num_reqs or num_reqs,
batch_desc.num_tokens,
self.input_buffers,
)
if not skip_attn_for_dummy_run:
block_tables, slot_mappings = self.prepare_dummy_attn(input_batch)
else:
assert batch_desc.cg_mode != CUDAGraphMode.FULL, (
"Attention metadata must be prepared for dummy runs when using "
"FULL cudagraph mode."
)
block_tables = None
slot_mappings = None
# FIXME(woosuk): Fix warmup for LoRA.
attn_metadata = None
slot_mappings_by_layer = None
if not (dummy_run and skip_attn_for_dummy_run):
assert slot_mappings is not None
slot_mappings_by_layer = build_slot_mappings_by_layer(
slot_mappings, self.kv_cache_config
)
assert block_tables is not None
attn_metadata = self.model_state.prepare_attn(
input_batch,
batch_desc.cg_mode,
block_tables,
slot_mappings,
self.attn_groups,
self.kv_cache_config,
)
inputs_embeds = None
if self.supports_mm_inputs and self.is_first_pp_rank:
# Run MM encoder (if needed) and get multimodal embeddings.
# Only first PP rank prepares multimodal embeddings.
# NOTE(woosuk): We must call get_mm_embeddings even during dummy runs
# to obtain inputs_embeds, because the compiled model expects this input.
inputs_embeds = self.model_state.get_mm_embeddings(
scheduler_output.scheduled_encoder_inputs,
input_batch,
self.req_states,
)
model_inputs = {
"input_ids": input_batch.input_ids,
"positions": input_batch.positions,
"inputs_embeds": inputs_embeds,
# NOTE: Values returned by `prepare_inputs` will override the default
# values above.
**self.model_state.prepare_inputs(input_batch, self.req_states),
}
if not self.is_first_pp_rank:
# Update for non-first PP ranks.
model_inputs["input_ids"] = None
model_inputs["inputs_embeds"] = None
# Prepare the intermediate tensors.
assert intermediate_tensors is not None
assert self.intermediate_tensors is not None
n = input_batch.num_tokens_after_padding
model_inputs["intermediate_tensors"] = IntermediateTensors(
{
k: v[:n].copy_(intermediate_tensors.tensors[k][:n])
for k, v in self.intermediate_tensors.tensors.items()
}
)
del intermediate_tensors
# Run model.
if batch_desc.cg_mode == CUDAGraphMode.FULL:
# Use explicit cudagraph replay for FULL mode.
# NOTE(woosuk): Here, we don't need to pass the input tensors,
# because they are already copied to the CUDA graph input buffers.
assert self.cudagraph_manager is not None
self.kv_connector.pre_forward(scheduler_output)
model_output = self.cudagraph_manager.run_fullgraph(batch_desc)
else:
# For piecewise and eager mode, just call model().
batch_descriptor = BatchDescriptor(
num_tokens=input_batch.num_tokens_after_padding,
has_lora=self.lora_config is not None,
)
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=input_batch.num_tokens_after_padding,
cudagraph_runtime_mode=batch_desc.cg_mode,
num_tokens_across_dp=num_tokens_across_dp,
batch_descriptor=batch_descriptor,
slot_mapping=slot_mappings_by_layer,
skip_compiled=skip_compiled,
):
self.kv_connector.pre_forward(scheduler_output)
model_output = self.model(**model_inputs)
if self.is_last_pp_rank:
if self.use_aux_hidden_state_outputs:
assert isinstance(model_output, tuple)
hidden_states, aux_hidden_states = model_output
else:
assert isinstance(model_output, torch.Tensor)
hidden_states = model_output
aux_hidden_states = None
output_intermediate_tensors = None
else:
assert isinstance(model_output, IntermediateTensors)
hidden_states = None
aux_hidden_states = None
output_intermediate_tensors = model_output
kv_connector_output = self.kv_connector.post_forward(scheduler_output)
self.execute_model_state = ExecuteModelState(
input_batch=input_batch,
attn_metadata=attn_metadata,
slot_mappings_by_layer=slot_mappings_by_layer,
hidden_states=hidden_states,
aux_hidden_states=aux_hidden_states,
kv_connector_output=kv_connector_output,
)
if not self.is_last_pp_rank:
# Non-last PP rank: return IntermediateTensors for sending.
assert output_intermediate_tensors is not None
output_intermediate_tensors.kv_connector_output = kv_connector_output
return output_intermediate_tensors
return None
@torch.inference_mode()
@step_eplb_after()
def sample_tokens(
self, grammar_output: GrammarOutput | None
) -> AsyncOutput | ModelRunnerOutput | None:
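"""Consume the state stashed by execute_model: sample (or rejection-sample)
tokens on the last PP rank, broadcast them to the other PP ranks, compute
prompt logprobs, propose draft tokens with the speculator if configured, and
return the model runner output (async when async scheduling is enabled)."""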
if self.execute_model_state is None:
# The prior execute_model call failed or skipped running the model.
return None
input_batch = self.execute_model_state.input_batch
attn_metadata = self.execute_model_state.attn_metadata
slot_mappings_by_layer = self.execute_model_state.slot_mappings_by_layer
hidden_states = self.execute_model_state.hidden_states
aux_hidden_states = self.execute_model_state.aux_hidden_states
kv_connector_output = self.execute_model_state.kv_connector_output
self.execute_model_state = None
if not self.is_last_pp_rank:
# Non-last PP rank: hidden_states is None because this rank produced
# IntermediateTensors instead of final hidden states. Receive the
# sampled tokens broadcast from the last rank and update local state.
sampled, num_sampled, num_rejected = pp_receive(
input_batch.num_reqs, max_sample_len=self.num_speculative_steps + 1
)
self.postprocess(input_batch, sampled, num_sampled, num_rejected)
return None
# Last rank: sample tokens
sampler_output, num_sampled, num_rejected = self.sample(
hidden_states, input_batch, grammar_output
)
if self.use_pp:
# Broadcast to non-last PP ranks (handles spec decode multi-token).
pp_broadcast(sampler_output.sampled_token_ids, num_sampled, num_rejected)
assert self.prompt_logprobs_worker is not None
prompt_logprobs_dict = self.prompt_logprobs_worker.compute_prompt_logprobs(
self.model.compute_logits,
hidden_states,
input_batch,
self.req_states.all_token_ids.gpu,
self.req_states.num_computed_tokens.gpu,
self.req_states.prompt_len.np,
self.req_states.prefill_len.np,
self.req_states.num_computed_prefill_tokens,
)
# Prepare the model runner output.
model_runner_output = ModelRunnerOutput(
req_ids=input_batch.req_ids,
# NOTE(woosuk): req_id_to_index is unused in this model runner.
# Only for compatibility with the existing model runner and scheduler.
req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
sampled_token_ids=None, # type: ignore
prompt_logprobs_dict=prompt_logprobs_dict, # type: ignore[arg-type]
kv_connector_output=kv_connector_output,
)
async_output = AsyncOutput(
model_runner_output=model_runner_output,
sampler_output=sampler_output,
num_sampled_tokens=num_sampled,
main_stream=self.main_stream,
copy_stream=self.output_copy_stream,
)
mm_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
if self.speculator is not None and self.speculator.supports_mm_inputs:
# Get cached multimodal embeddings for draft forward.
# NOTE: This is done here because postprocess updates
# num_computed_prefill_tokens.
prefill_lens = self.req_states.prefill_len.np[input_batch.idx_mapping_np]
computed_prefill_lens = self.req_states.num_computed_prefill_tokens[
input_batch.idx_mapping_np
]
mm_inputs = self.model_state.encoder_runner.gather_mm_embeddings(
input_batch.req_ids,
input_batch.num_tokens,
input_batch.num_scheduled_tokens,
input_batch.query_start_loc_np,
prefill_lens,
computed_prefill_lens + 1,  # +1 to account for the one-token shift (skew) of the EAGLE drafter inputs
)
# Postprocess results and update request states.
# NOTE: This is intentionally done after creating the AsyncOutput,
# ensuring that `copy_event` is recorded before calling postprocess.
# This sequencing may slightly reduce latency as async D2H copy does not
# need to wait for the postprocess to finish.
self.postprocess(
input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
)
if self.speculator is not None:
assert self.sampler is not None
# Let the target override the hidden state fed to the drafter
# (e.g. DeepSeek V4 MTP needs the pre-hc_head residual). The
# target returns a persistent buffer sized at max_num_batched_tokens;
# slice to the active token count that propose() expects.
spec_hidden_states = hidden_states
if hasattr(self.model, "get_mtp_target_hidden_states"):
pre_hc_hidden_states = self.model.get_mtp_target_hidden_states()
spec_hidden_states = pre_hc_hidden_states[: hidden_states.shape[0]] # type: ignore[union-attr]
draft_tokens = self.speculator.propose(
input_batch,
attn_metadata,
slot_mappings_by_layer,
spec_hidden_states,
aux_hidden_states,
num_sampled,
num_rejected,
self.req_states.last_sampled_tokens,
self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
mm_inputs=mm_inputs,
)
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
self.draft_tokens_handler.set_draft_tokens(input_batch, draft_tokens)
if self.use_async_scheduling:
return async_output
return async_output.get_output()
def take_draft_token_ids(self) -> DraftTokenIds | None:
return self.draft_tokens_handler.get_draft_tokens()
@torch.inference_mode()
@step_eplb_after()
def pool(self) -> AsyncPoolingOutput | ModelRunnerOutput | None:
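"""Consume the state stashed by execute_model and run the pooler on the last
PP rank, returning the pooling output (async when async scheduling is
enabled); non-last PP ranks only update the computed-token counters."""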
if self.execute_model_state is None:
# The prior execute_model call failed or skipped running the model.
return None
input_batch = self.execute_model_state.input_batch
hidden_states = self.execute_model_state.hidden_states
kv_connector_output = self.execute_model_state.kv_connector_output
self.execute_model_state = None
if not self.is_last_pp_rank:
self.postprocess_pool(input_batch)
return None
assert self.pooling_runner is not None
pooler_output, is_valid = self.pooling_runner.pool(
hidden_states, input_batch, self.req_states
)
# Build the model runner output.
model_runner_output = ModelRunnerOutput(
req_ids=input_batch.req_ids,
req_id_to_index={req_id: i for i, req_id in enumerate(input_batch.req_ids)},
kv_connector_output=kv_connector_output,
)
async_output = AsyncPoolingOutput(
model_runner_output=model_runner_output,
pooler_output=pooler_output,
is_valid=is_valid,
main_stream=self.main_stream,
copy_stream=self.output_copy_stream,
)
self.postprocess_pool(input_batch)
if self.use_async_scheduling:
return async_output
return async_output.get_output()
def postprocess_pool(self, input_batch: InputBatch) -> None:
# Update the number of computed tokens.
post_update_pool(
input_batch.idx_mapping,
self.req_states.num_computed_tokens.gpu,
input_batch.query_start_loc,
)
# Update the number of computed prefill tokens.
idx_mapping_np = input_batch.idx_mapping_np
computed_prefill = self.req_states.num_computed_prefill_tokens
computed_prefill[idx_mapping_np] += input_batch.num_scheduled_tokens
np.minimum(
computed_prefill, self.req_states.prefill_len.np, out=computed_prefill
)
# Advance the CPU mirror optimistically (assume all scheduled accepted).
self.req_states.num_computed_tokens_np[idx_mapping_np] += (
input_batch.num_scheduled_tokens
)
def shutdown(self) -> None:
"""Release GPU tensors (model weights, KV caches, workspace) so that
memory is reclaimable when running in the same process."""
torch.accelerator.synchronize()
if hasattr(self, "kv_caches"):
self.kv_caches.clear()
if hasattr(self, "attn_groups"):
self.attn_groups.clear()
if hasattr(self, "kv_cache_config"):
del self.kv_cache_config
free_before_shutdown(self.vllm_config)
if hasattr(self, "model"):
del self.model
gc.collect()
torch.accelerator.empty_cache()
logger.debug("Cleaned up model weights, KV caches, and workspace")
########### EPLB methods start ###########
@property
def eplb_state(self):
return self.eplb.state
@eplb_state.setter
def eplb_state(self, state) -> None:
self.eplb.state = state
@property
def eep_eplb_suppressed(self) -> bool:
return self.eplb.suppressed
@eep_eplb_suppressed.setter
def eep_eplb_suppressed(self, suppressed: bool) -> None:
self.eplb.suppressed = suppressed
def setup_eplb_from_mapping(
self,
expanded_physical_to_logical: torch.Tensor,
old_num_physical_experts: int,
) -> None:
self.eplb.setup_from_mapping(
self.model,
self.model_config,
expanded_physical_to_logical,
old_num_physical_experts,
)