Per-request override of which token ids' logprobs to return.
See `SamplingParams.logprob_token_ids`.
Source code in `vllm/v1/worker/gpu/sample/logprob.py`
```python
class LogprobTokenIdsState:
    """Per-request override of which token ids' logprobs to return.

    See `SamplingParams.logprob_token_ids`.
    """

    def __init__(self, max_num_reqs: int, device: torch.device):
        self.max_num_reqs = max_num_reqs
        self.num_token_ids = UvaBackedTensor(max_num_reqs, dtype=torch.int32)
        self.token_ids = StagedWriteTensor(
            (max_num_reqs, MAX_LOGPROB_TOKEN_IDS),
            dtype=torch.int32,
            device=device,
        )

    def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
        token_ids = sampling_params.logprob_token_ids
        if not token_ids:
            self.num_token_ids.np[req_idx] = 0
            return
        n = len(token_ids)
        if n > MAX_LOGPROB_TOKEN_IDS:
            raise ValueError(
                f"Too many logprob_token_ids: {n}. The max is {MAX_LOGPROB_TOKEN_IDS}."
            )
        self.num_token_ids.np[req_idx] = n
        self.token_ids.stage_write(req_idx, 0, token_ids)

    def apply_staged_writes(self) -> None:
        self.num_token_ids.copy_to_uva()
        self.token_ids.apply_write()

    def max_num_token_ids(self, idx_mapping_np: np.ndarray) -> int:
        return int(self.num_token_ids.np[idx_mapping_np].max(initial=0))
```
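For orientation, here is a minimal sketch of the `add_request` → `apply_staged_writes` → `max_num_token_ids` lifecycle. It uses plain NumPy/PyTorch buffers as stand-ins for `UvaBackedTensor` and `StagedWriteTensor` (vLLM internals not shown here), and the `MAX_LOGPROB_TOKEN_IDS` value and `FakeSamplingParams` dataclass below are illustrative assumptions, not the real classes:

```python
# Illustrative stand-in for the staging pattern above; not the vLLM implementation.
from dataclasses import dataclass, field

import numpy as np
import torch

MAX_LOGPROB_TOKEN_IDS = 128  # assumed cap, mirroring the check in add_request


@dataclass
class FakeSamplingParams:
    # stand-in for SamplingParams.logprob_token_ids
    logprob_token_ids: list[int] = field(default_factory=list)


class SimpleLogprobTokenIdsState:
    """Per-request token-id lists staged on the CPU, flushed to the device in one step."""

    def __init__(self, max_num_reqs: int, device: torch.device):
        self.num_token_ids = np.zeros(max_num_reqs, dtype=np.int32)
        # Writes are staged in a host buffer and copied to the device buffer later.
        self._staged = np.zeros((max_num_reqs, MAX_LOGPROB_TOKEN_IDS), dtype=np.int32)
        self.token_ids = torch.zeros(
            (max_num_reqs, MAX_LOGPROB_TOKEN_IDS), dtype=torch.int32, device=device
        )

    def add_request(self, req_idx: int, params: FakeSamplingParams) -> None:
        ids = params.logprob_token_ids
        if len(ids) > MAX_LOGPROB_TOKEN_IDS:
            raise ValueError(f"Too many logprob_token_ids: {len(ids)}.")
        self.num_token_ids[req_idx] = len(ids)
        self._staged[req_idx, : len(ids)] = ids

    def apply_staged_writes(self) -> None:
        # One bulk host-to-device copy instead of a copy per request.
        self.token_ids.copy_(torch.from_numpy(self._staged))

    def max_num_token_ids(self, idx_mapping: np.ndarray) -> int:
        # Largest number of requested token ids among the mapped requests (0 if none).
        return int(self.num_token_ids[idx_mapping].max(initial=0))


# Usage: stage two requests, flush once, then size the per-step logprob gather.
state = SimpleLogprobTokenIdsState(max_num_reqs=4, device=torch.device("cpu"))
state.add_request(0, FakeSamplingParams(logprob_token_ids=[11, 42, 7]))
state.add_request(1, FakeSamplingParams(logprob_token_ids=[3]))
state.apply_staged_writes()
print(state.max_num_token_ids(np.array([0, 1])))  # -> 3
```

The point of the staged-write pattern is that per-request writes stay on the host and are flushed in a single copy per step, which is what `apply_staged_writes` does with the real UVA-backed and staged tensors.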