Skip to content

vllm.distributed.kv_transfer.kv_connector.v1.offloading.scheduler

OffloadingConnectorScheduler

Implementation of Scheduler side methods

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
class OffloadingConnectorScheduler:
    """Implementation of Scheduler side methods.

    Tracks per-request offload state, decides which offloaded KV blocks can
    be loaded back to the GPU, builds load/store transfer jobs for the
    workers, and fences in-flight stores against GPU block reuse.
    """

    def __init__(self, spec: OffloadingSpec):
        self.config = SchedulerOffloadConfig.from_spec(spec)
        self.manager: OffloadingManager = spec.get_manager()

        # Partition KV cache groups by attention type: full-attention groups
        # use prefix lookups, sliding-window groups use suffix lookups.
        full_attention_groups: list[int] = []
        sliding_window_groups: list[int] = []
        for group_config in self.config.kv_group_configs:
            if group_config.sliding_window_size_in_blocks is None:
                full_attention_groups.append(group_config.group_idx)
            else:
                sliding_window_groups.append(group_config.group_idx)

        # sort sliding window groups by window size in decreasing order
        def _sliding_window_sort_key(i: int) -> int:
            val = self.config.kv_group_configs[i].sliding_window_size_in_blocks
            assert val is not None
            return val

        sliding_window_groups.sort(key=_sliding_window_sort_key, reverse=True)

        # used by _lookup; full-attention groups are looked up first
        self._sliding_window_groups: tuple[int, ...] = tuple(sliding_window_groups)
        self._lookup_groups = tuple(full_attention_groups) + self._sliding_window_groups

        # request_id -> offload state; lives until the request finishes and
        # its last transfer job completes
        self._req_status: dict[ReqId, RequestOffloadState] = {}
        # load jobs created during the current scheduler step
        self._current_batch_load_jobs: dict[int, TransferJob] = {}
        # store job IDs that must be flushed this step (their GPU blocks
        # are being reused)
        self._current_batch_jobs_to_flush: set[int] = set()
        # if GPU prefix caching is enabled,
        # track loaded blocks to avoid redundant loads
        self._blocks_being_loaded: set[OffloadKey] | None = (
            set() if spec.vllm_config.cache_config.enable_prefix_caching else None
        )

        # Job ID counter shared by loads and stores.
        self._job_counter: int = 0
        self._jobs: dict[int, TransferJobStatus] = {}

        # block_id -> pending store job_ids. Used to track jobs that need
        # flushing in case a block is re-allocated by the KV cache manager.
        # Populated only for finished requests (running-request blocks are
        # protected by their ref_cnt) and for sliding window blocks (which can
        # be freed before a request finishes).
        self._block_id_to_pending_jobs: dict[int, set[int]] = {}

    def _generate_job_id(self) -> int:
        """Return the next monotonically-increasing transfer job ID."""
        next_id = self._job_counter
        self._job_counter = next_id + 1
        return next_id

    def _remove_pending_job(self, job_id: int, block_ids: list[int] | None) -> None:
        """Unregister `job_id` from each block's pending-store set.

        A block's entry is dropped entirely once its last pending job goes.
        """
        if not block_ids:
            return
        for block_id in block_ids:
            jobs = self._block_id_to_pending_jobs[block_id]
            jobs.remove(job_id)
            if not jobs:
                del self._block_id_to_pending_jobs[block_id]

    def _maximal_prefix_lookup(
        self, keys: Iterable[OffloadKey], req_context: ReqContext
    ) -> int | None:
        """Count consecutive offloaded-block hits from the start of `keys`.

        When the backend defers a lookup (returns None) we keep scanning,
        treating the block as a hit, so the manager can kick off async
        lookups for the whole prefix; the caller then gets None.
        """
        hits = 0
        deferred = False
        for key in keys:
            hit = self.manager.lookup(key, req_context)
            if hit is None:
                # backend deferred; keep issuing lookups until a real miss
                deferred = True
                hit = True
            if not hit:
                break
            hits += 1
        if deferred:
            return None
        return hits

    def _sliding_window_lookup(
        self,
        keys: Sequence[OffloadKey],
        sliding_window_size: int,
        req_context: ReqContext,
    ) -> int | None:
        """Scan `keys` back-to-front for a run of `sliding_window_size`
        consecutive hits.

        Returns:
            - ``idx + sliding_window_size`` -- the exclusive end index (in
              `keys`) of the right-most full-window run, when one is found;
            - otherwise the length of the consecutive-hit run at the start
              of `keys` (0 when keys[0] itself misses);
            - None when the backend deferred any lookup.
        """
        defer_lookup = False
        consecutive_hits = 0
        for idx in range(len(keys) - 1, -1, -1):
            result = self.manager.lookup(keys[idx], req_context)
            if result is None:
                defer_lookup = True
                # continue lookup to allow manager to kick-off async lookups
                # for all blocks (until a hit is detected)
                # NOTE: a deferred lookup is conservatively treated as a miss
                result = False
            if not result:
                # a miss breaks the current run
                consecutive_hits = 0
            else:
                consecutive_hits += 1
                if consecutive_hits == sliding_window_size:
                    return idx + sliding_window_size if not defer_lookup else None
        # loop exhausted without a full-window run: consecutive_hits now
        # counts the hits contiguous from the start of `keys`
        return consecutive_hits if not defer_lookup else None

    def _touch(self, req_status: RequestOffloadState):
        """Refresh eviction priority for this request's offloaded blocks."""
        for cfg, state in zip(self.config.kv_group_configs, req_status.group_states):
            window = cfg.sliding_window_size_in_blocks
            if window is None:
                # full attention: every offloaded block may be reused
                keys = state.offload_keys
            else:
                # we aim to keep just blocks that are necessary to hit
                # the original request (+ decoded blocks)
                skip = max(0, state.num_hit_blocks - window)
                keys = state.offload_keys[skip:]
            self.manager.touch(keys)

    def _lookup(self, req_status: RequestOffloadState) -> int | None:
        """
        Find how many tokens beyond num_locally_computed_tokens can be loaded.

        Iterates full-attention groups first (prefix lookup), then sliding-window
        groups (suffix lookup). Each group may tighten max_hit_size_tokens, which
        can invalidate an earlier group's result, so the loop re-runs when that
        happens until num_hit_tokens converges.

        Returns:
            The number of loadable tokens (0 when not worth loading), or
            None when the request should be retried later (backend deferred
            a lookup, or hit blocks are still being loaded by another
            request).
        """
        num_computed_tokens = req_status.num_locally_computed_tokens
        max_hit_size_tokens: int = req_status.req.num_tokens
        if self._sliding_window_groups:
            # the last prompt token has to be recomputed to get the logprobs
            # for sliding window attention, we must reduce by 1 to make sure
            # we still have a hit after reduction
            max_hit_size_tokens -= 1
        num_hit_tokens: int = 0
        defer_lookup = False
        lookup_groups = self._lookup_groups
        while lookup_groups:
            looked_up_sliding_window: bool = False
            groups_iter = iter(lookup_groups)
            # cleared now; re-populated below if another pass is required
            lookup_groups = ()
            for group_idx in groups_iter:
                group_config: GroupOffloadConfig = self.config.kv_group_configs[
                    group_idx
                ]
                group_state: RequestGroupState = req_status.group_states[group_idx]
                offloaded_block_size = group_config.offloaded_block_size
                offload_keys = group_state.offload_keys

                assert (
                    len(offload_keys)
                    >= req_status.req.num_tokens // offloaded_block_size
                )

                # Constrain to block-aligned boundary for this group
                max_hit_size_tokens = min(
                    max_hit_size_tokens, len(offload_keys) * offloaded_block_size
                )
                if max_hit_size_tokens - num_computed_tokens < offloaded_block_size:
                    # we can only load less than a block, better skip
                    return 0

                num_blocks = min(
                    cdiv(max_hit_size_tokens, offloaded_block_size), len(offload_keys)
                )
                # only look up blocks past the locally computed prefix
                start_block_idx = num_computed_tokens // offloaded_block_size
                offload_keys = offload_keys[start_block_idx:num_blocks]
                sliding_window_size_in_blocks = (
                    group_config.sliding_window_size_in_blocks
                )

                # end index (in the sliced offload_keys) up to which we
                # have backend-confirmed hits
                num_hit_blocks: int | None
                if sliding_window_size_in_blocks is None:
                    num_hit_blocks = self._maximal_prefix_lookup(
                        offload_keys, req_status.req_context
                    )
                else:
                    num_hit_blocks = self._sliding_window_lookup(
                        offload_keys,
                        sliding_window_size_in_blocks,
                        req_status.req_context,
                    )
                if num_hit_blocks == 0:
                    return 0

                if num_hit_blocks is None:
                    # backend deferred; remember, but keep scanning groups
                    defer_lookup = True
                else:
                    max_hit_size_tokens = min(
                        max_hit_size_tokens,
                        offloaded_block_size * (start_block_idx + num_hit_blocks),
                    )

                new_num_hit_tokens = max_hit_size_tokens - num_computed_tokens
                if new_num_hit_tokens < offloaded_block_size:
                    # we can only load less than a block, better skip
                    return 0

                # this group tightened the hit below an earlier group's
                # result -> earlier results may be stale
                if new_num_hit_tokens < num_hit_tokens:
                    if defer_lookup:
                        # make another iteration on all groups to check
                        # if we still need to defer lookup
                        defer_lookup = False
                        lookup_groups = self._lookup_groups
                    elif looked_up_sliding_window and not lookup_groups:
                        # we need another iteration to confirm previously looked up
                        # sliding window works with the new_num_hit_tokens
                        lookup_groups = self._sliding_window_groups

                looked_up_sliding_window |= sliding_window_size_in_blocks is not None
                num_hit_tokens = new_num_hit_tokens

        if defer_lookup:
            logger.debug(
                "Offloading manager delayed request %s as backend requested",
                req_status.req.request_id,
            )
            return None

        # possibly delay request if any of the hit blocks is already being loaded
        if self._blocks_being_loaded:
            for group_config, group_state in zip(
                self.config.kv_group_configs, req_status.group_states
            ):
                offloaded_block_size = group_config.offloaded_block_size
                sliding_window_size_in_blocks = (
                    group_config.sliding_window_size_in_blocks
                )
                offload_keys = group_state.offload_keys
                num_blocks = cdiv(
                    num_computed_tokens + num_hit_tokens, offloaded_block_size
                )
                start_block_idx = num_computed_tokens // offloaded_block_size
                offload_keys = offload_keys[start_block_idx:num_blocks]
                if sliding_window_size_in_blocks is not None:
                    # only the trailing window of blocks would be loaded
                    offload_keys = offload_keys[-sliding_window_size_in_blocks:]
                if any(key in self._blocks_being_loaded for key in offload_keys):
                    # hit blocks are being loaded, delay request
                    logger.debug(
                        "Delaying request %s since some of its"
                        " blocks are already being loaded",
                        req_status.req.request_id,
                    )
                    return None

        logger.debug(
            "Request %s hit %s offloaded tokens after %s GPU hit tokens",
            req_status.req.request_id,
            num_hit_tokens,
            num_computed_tokens,
        )

        return num_hit_tokens

    def get_num_new_matched_tokens(
        self, request: Request, num_computed_tokens: int
    ) -> tuple[int | None, bool]:
        """Report how many tokens past `num_computed_tokens` can be loaded
        from the offloading backend.

        Args:
            request: the request being scheduled.
            num_computed_tokens: tokens already computed locally for this
                request.

        Returns:
            A pair (num_tokens, is_async): num_tokens is the count of
            additional loadable tokens, or None when the connector needs
            more time and the scheduler should query again later; is_async
            is True when a non-zero load will run between scheduler steps.
        """
        req_id = request.request_id
        req_status = self._req_status.get(req_id)
        is_new_request = not req_status
        if req_status:
            # repeated query for a known request: drop stale block IDs
            for group_state in req_status.group_states:
                group_state.block_ids.clear()
        else:
            req_status = RequestOffloadState(config=self.config, req=request)
            self._req_status[req_id] = req_status

        req_status.update_offload_keys()
        req_status.num_locally_computed_tokens = num_computed_tokens

        num_hit_tokens = self._lookup(req_status)
        if is_new_request:
            req_status.update_num_hit_blocks(
                num_computed_tokens + (num_hit_tokens or 0)
            )

        # refresh eviction priority for the blocks we intend to reuse
        self._touch(req_status)

        return num_hit_tokens, bool(num_hit_tokens)

    def update_state_after_alloc(
        self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int
    ):
        """Build the load job copying offloaded KV data into the GPU blocks
        just allocated for `request`.

        Args:
            request: the request object.
            blocks: GPU blocks allocated by the KV cache manager, one list
                per KV cache group.
            num_external_tokens: number of tokens to load from the
                offloading backend; 0 means nothing to do.
        """
        if num_external_tokens == 0:
            return

        req_status = self._req_status[request.request_id]

        num_locally_computed_tokens = req_status.num_locally_computed_tokens
        num_cached_tokens = num_locally_computed_tokens + num_external_tokens

        params = req_status.req_context.kv_transfer_params
        do_remote_decode = params is not None and params.get("do_remote_decode")

        keys_to_load: list[OffloadKey] = []
        dst_block_ids: list[int] = []
        # per group
        group_sizes: list[int] = []
        block_indices: list[int] = []
        for group_config, group_state, group_blocks in zip(
            self.config.kv_group_configs,
            req_status.group_states,
            blocks.blocks,
        ):
            gpu_block_size = group_config.gpu_block_size
            offloaded_block_size = group_config.offloaded_block_size
            offload_keys = group_state.offload_keys
            num_gpu_blocks = cdiv(num_cached_tokens, gpu_block_size)

            assert len(group_blocks) >= num_gpu_blocks
            num_locally_computed_gpu_blocks = num_gpu_blocks
            # Skip null placeholder blocks (used for sliding window or mamba padding).
            # The first non-null block without a hash marks where locally
            # computed data ends and loaded data must be written.
            for i, block in enumerate(group_blocks[:num_gpu_blocks]):
                if not block.is_null and block.block_hash is None:
                    num_locally_computed_gpu_blocks = i
                    break

            assert (
                num_locally_computed_tokens
                <= num_locally_computed_gpu_blocks * gpu_block_size
            )
            num_pending_gpu_blocks = num_gpu_blocks - num_locally_computed_gpu_blocks

            if group_config.sliding_window_size_in_blocks is not None:
                # a sliding-window group never needs more than a window's
                # worth of GPU blocks loaded
                assert (
                    num_pending_gpu_blocks
                    <= group_config.sliding_window_size_in_blocks
                    * self.config.block_size_factor
                )

            num_blocks = cdiv(num_cached_tokens, offloaded_block_size)
            assert len(offload_keys) >= num_blocks
            if num_pending_gpu_blocks:
                # translate the GPU-block boundary to offloaded-block units
                start_block_idx = (
                    num_locally_computed_gpu_blocks // self.config.block_size_factor
                )
                keys_to_load.extend(offload_keys[start_block_idx:num_blocks])

            dst_block_ids.extend(
                block.block_id
                for block in group_blocks[
                    num_locally_computed_gpu_blocks:num_gpu_blocks
                ]
            )
            group_sizes.append(num_pending_gpu_blocks)
            block_indices.append(num_locally_computed_gpu_blocks)

            if not do_remote_decode:
                # For P/D prefill requests (do_remote_decode=True), we do
                # NOT skip saving the hit prefix, as we need to stream the
                # entire KV cache so a remote decode node can consume it.
                group_state.next_stored_block_idx = num_blocks

        # Fence dst blocks against finished-request pending stores.
        if (
            self._block_id_to_pending_jobs
            and not self._block_id_to_pending_jobs.keys().isdisjoint(dst_block_ids)
        ):
            self._current_batch_jobs_to_flush.update(
                jid
                for bid in dst_block_ids
                for jid in self._block_id_to_pending_jobs.get(bid, ())
            )

        src_spec = self.manager.prepare_load(keys_to_load, req_status.req_context)
        dst_spec = GPULoadStoreSpec(
            dst_block_ids, group_sizes=group_sizes, block_indices=block_indices
        )

        load_job_id = self._generate_job_id()
        self._current_batch_load_jobs[load_job_id] = TransferJob(
            req_id=request.request_id,
            transfer_spec=(src_spec, dst_spec),
        )
        # a load can only be issued when no other jobs are pending.
        assert not req_status.transfer_jobs
        req_status.transfer_jobs.add(load_job_id)
        self._jobs[load_job_id] = TransferJobStatus(
            req_id=request.request_id,
            pending_count=self.config.num_workers,
            keys=set(keys_to_load),
            is_store=False,
        )

        if self._blocks_being_loaded is not None:
            self._blocks_being_loaded.update(keys_to_load)

    def _build_store_jobs(
        self,
        scheduler_output: SchedulerOutput,
    ) -> dict[int, TransferJob]:
        """Create store (offload) jobs for blocks completed in this batch.

        Returns:
            Mapping of job_id -> TransferJob describing GPU-to-backend
            copies for the workers to execute.
        """
        block_size_factor = self.config.block_size_factor
        store_jobs: dict[int, TransferJob] = {}
        # iterate over both new and cached requests
        for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output):
            req_status = self._req_status[req_id]
            req_status.update_offload_keys()
            req = req_status.req

            if preempted:
                # GPU blocks were revoked; recorded block IDs are stale
                for group_state in req_status.group_states:
                    group_state.block_ids.clear()

            if new_block_id_groups:
                req_status.update_block_id_groups(new_block_id_groups)
                # Fence new blocks against in-flight stores.
                if self._block_id_to_pending_jobs:
                    new_blocks_flat = [
                        bid for new_blocks in new_block_id_groups for bid in new_blocks
                    ]
                    if not self._block_id_to_pending_jobs.keys().isdisjoint(
                        new_blocks_flat
                    ):
                        self._current_batch_jobs_to_flush.update(
                            jid
                            for bid in new_blocks_flat
                            for jid in self._block_id_to_pending_jobs.get(bid, ())
                        )

            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
            num_tokens_after_batch = req.num_computed_tokens + num_scheduled_tokens
            # with async scheduling, some tokens may be missing
            num_offloadable_tokens = min(num_tokens_after_batch, req.num_tokens)

            # Filter out blocks skipped due to sliding window attention / SSM
            new_offload_keys: list[OffloadKey] = []
            for group_config, group_state in zip(
                self.config.kv_group_configs, req_status.group_states
            ):
                num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
                start_block_idx = group_state.next_stored_block_idx
                if num_blocks <= start_block_idx:
                    # no newly completed offloaded block in this group
                    continue
                offload_keys = group_state.offload_keys[start_block_idx:num_blocks]
                # For each block to offload, take the last corresponding GPU block.
                # e.g. if block size factor is 3 and GPU block IDs are
                # 1 5 6 7 2 4 9 3 8 then we'll take blocks 6 4 8.
                # We will use these GPU blocks to determine if the block needs
                # offloading, or (if the GPU block ID is 0) this block should
                # be skipped due to sliding window attention / SSM.
                # We know that if a block is skipped, then all the previous blocks
                # are skipped as well. This is why we take the last of each block.
                offload_block_ids = group_state.block_ids[
                    start_block_idx * block_size_factor
                    + block_size_factor
                    - 1 : num_blocks * block_size_factor : block_size_factor
                ]
                assert len(offload_keys) == len(offload_block_ids)

                for offload_key, block_id in zip(offload_keys, offload_block_ids):
                    if block_id != 0:
                        new_offload_keys.append(offload_key)

            if not new_offload_keys:
                # nothing to store; still advance so we don't re-scan
                req_status.advance_stored_idx(num_offloadable_tokens)
                continue

            store_output = self.manager.prepare_store(
                new_offload_keys, req_status.req_context
            )
            if store_output is None:
                # backend cannot accept a store right now; not advancing the
                # stored index means this will be retried next batch
                logger.warning("Request %s: cannot store blocks", req_id)
                continue

            if not store_output.keys_to_store:
                # backend already holds all of these blocks
                req_status.advance_stored_idx(num_offloadable_tokens)
                continue

            self._touch(req_status)

            keys_to_store = set(store_output.keys_to_store)

            group_sizes: list[int] = []
            block_indices: list[int] = []
            src_block_ids: list[int] = []
            sliding_window_block_ids: list[int] = []
            non_sliding_window_block_ids: list[int] = []
            for group_config, group_state in zip(
                self.config.kv_group_configs, req_status.group_states
            ):
                is_sliding_window = (
                    group_config.sliding_window_size_in_blocks is not None
                )
                num_blocks = num_offloadable_tokens // group_config.offloaded_block_size
                start_block_idx = group_state.next_stored_block_idx
                block_ids = group_state.block_ids
                num_group_blocks = 0
                start_gpu_block_idx: int | None = None
                for idx, offload_key in enumerate(
                    group_state.offload_keys[start_block_idx:num_blocks]
                ):
                    if offload_key not in keys_to_store:
                        continue

                    offloaded_block_idx = start_block_idx + idx
                    gpu_block_idx = offloaded_block_idx * block_size_factor
                    num_group_blocks += block_size_factor
                    for i in range(block_size_factor):
                        block_id = block_ids[gpu_block_idx + i]
                        if block_id == 0:
                            # skipped blocks cannot appear after non-skipped blocks
                            assert start_gpu_block_idx is None
                            continue
                        elif start_gpu_block_idx is None:
                            start_gpu_block_idx = gpu_block_idx + i
                        src_block_ids.append(block_id)
                        if is_sliding_window:
                            sliding_window_block_ids.append(block_id)
                        else:
                            non_sliding_window_block_ids.append(block_id)

                group_sizes.append(num_group_blocks)
                block_indices.append(start_gpu_block_idx or 0)
                group_state.next_stored_block_idx = num_blocks

            src_spec = GPULoadStoreSpec(
                src_block_ids, group_sizes=group_sizes, block_indices=block_indices
            )
            dst_spec = store_output.store_spec

            job_id = self._generate_job_id()
            # a store can only be issued when no load is pending.
            if req_status.transfer_jobs:
                any_jid = next(iter(req_status.transfer_jobs))
                assert self._jobs[any_jid].is_store
            req_status.transfer_jobs.add(job_id)

            # Watch sliding window blocks as they may get evicted
            # before the request finishes
            for bid in sliding_window_block_ids or ():
                self._block_id_to_pending_jobs.setdefault(bid, set()).add(job_id)

            # the non-sliding window blocks will be watched only
            # when the request finishes
            self._jobs[job_id] = TransferJobStatus(
                req_id=req_id,
                pending_count=self.config.num_workers,
                keys=set(keys_to_store),
                is_store=True,
                non_sliding_window_block_ids=non_sliding_window_block_ids,
                sliding_window_block_ids=sliding_window_block_ids or None,
            )

            store_jobs[job_id] = TransferJob(
                req_id=req_id, transfer_spec=(src_spec, dst_spec)
            )

            logger.debug(
                "Request %s offloading %s blocks upto %d tokens (job %d)",
                req_id,
                len(keys_to_store),
                num_offloadable_tokens,
                job_id,
            )

        return store_jobs

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """Assemble this step's connector metadata and reset batch state."""
        # Preempted requests lose their GPU blocks, so any pending store
        # jobs they still own must be flushed before the blocks are reused.
        for preempted_id in scheduler_output.preempted_req_ids or ():
            status = self._req_status.get(preempted_id)
            if status is None or not status.transfer_jobs:
                continue
            sample_jid = next(iter(status.transfer_jobs))
            # only stores can be pending for a preempted request
            assert self._jobs[sample_jid].is_store
            self._current_batch_jobs_to_flush.update(status.transfer_jobs)

        # NOTE: _build_store_jobs may add more job IDs to
        # _current_batch_jobs_to_flush, so it runs before meta is built.
        store_jobs = self._build_store_jobs(scheduler_output)
        meta = OffloadingConnectorMetadata(
            load_jobs=self._current_batch_load_jobs,
            store_jobs=store_jobs,
            jobs_to_flush=self._current_batch_jobs_to_flush,
        )
        # start fresh for the next scheduler step
        self._current_batch_load_jobs = {}
        self._current_batch_jobs_to_flush = set()
        return meta

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from worker-side connectors output.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors output.
        """
        meta = connector_output.kv_connector_worker_meta
        if not isinstance(meta, OffloadingWorkerMetadata):
            # no worker metadata this step; treat as zero completions
            assert meta is None
            meta = OffloadingWorkerMetadata()
        for job_id, count in meta.completed_jobs.items():
            assert count > 0
            job_status = self._jobs[job_id]
            # pending_count starts at num_workers; each worker's completion
            # report decrements it
            job_status.pending_count -= count
            if job_status.pending_count > 0:
                # still waiting for the remaining workers
                continue
            assert job_status.pending_count == 0

            # all workers done: finalize the transfer with the manager
            if job_status.is_store:
                self.manager.complete_store(job_status.keys)
            else:
                self.manager.complete_load(job_status.keys)
                if self._blocks_being_loaded:
                    self._blocks_being_loaded.difference_update(job_status.keys)

            req_status = self._req_status[job_status.req_id]
            if self._block_id_to_pending_jobs:
                # Sliding window blocks are tracked from store creation
                # and must be cleaned up unconditionally.
                self._remove_pending_job(job_id, job_status.sliding_window_block_ids)
                # Non-sliding-window blocks are only tracked after
                # request_finished, so only clean up for finished requests.
                if req_status.req.is_finished():
                    self._remove_pending_job(
                        job_id, job_status.non_sliding_window_block_ids
                    )

            del self._jobs[job_id]
            req_status.transfer_jobs.remove(job_id)
            if not req_status.transfer_jobs and req_status.req.is_finished():
                # last in-flight job of a finished request: drop its state
                del self._req_status[job_status.req_id]

    def request_finished(
        self,
        request: Request,
    ) -> tuple[bool, dict[str, Any] | None]:
        """Handle request completion, before its GPU blocks are freed.

        Any still-pending store jobs are registered against their GPU block
        IDs so that later reuse of those blocks triggers a flush.

        Returns:
            (False, None): blocks never need to be held back for this
            connector, and no KV transfer params are attached to the
            request output.
        """
        # TODO(orozery): possibly kickoff offload for last block
        # which may have been deferred due to async scheduling
        status = self._req_status.get(request.request_id)
        if status is not None:
            if not status.transfer_jobs:
                # nothing in flight: drop the state immediately
                del self._req_status[request.request_id]
            else:
                # Pending stores will outlive the request's block ownership.
                # Register them so future block reuse triggers a flush.
                for pending_job_id in status.transfer_jobs:
                    pending_job = self._jobs[pending_job_id]
                    for block_id in pending_job.non_sliding_window_block_ids or ():
                        self._block_id_to_pending_jobs.setdefault(
                            block_id, set()
                        ).add(pending_job_id)
        return False, None

    def take_events(self) -> Iterable[KVCacheEvent]:
        """Drain the offloading manager's events, translated to KV cache
        events.

        Returns:
            A list of KV cache events.
        """
        for event in self.manager.take_events():
            hashes = [get_offload_block_hash(key) for key in event.keys]
            if event.removed:
                yield BlockRemoved(block_hashes=hashes, medium=event.medium)
                continue
            yield BlockStored(
                block_hashes=hashes,
                parent_block_hash=None,
                token_ids=[],
                lora_id=None,
                block_size=0,
                medium=event.medium,
                lora_name=None,
            )

    def shutdown(self) -> None:
        """Shut down the underlying offloading manager."""
        self.manager.shutdown()

_lookup

_lookup(req_status: RequestOffloadState) -> int | None

Find how many tokens beyond num_locally_computed_tokens can be loaded.

Iterates full-attention groups first (prefix lookup), then sliding-window groups (suffix lookup). Each group may tighten max_hit_size_tokens, which can invalidate an earlier group's result, so the loop re-runs when that happens until num_hit_tokens converges.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
def _lookup(self, req_status: RequestOffloadState) -> int | None:
    """
    Find how many tokens beyond num_locally_computed_tokens can be loaded.

    Iterates full-attention groups first (prefix lookup), then sliding-window
    groups (suffix lookup). Each group may tighten max_hit_size_tokens, which
    can invalidate an earlier group's result, so the loop re-runs when that
    happens until num_hit_tokens converges.

    Returns:
        The number of hit tokens beyond num_locally_computed_tokens
        (0 means skip loading), or None if the lookup should be retried
        later (the backend deferred a lookup, or some of the hit blocks
        are still being loaded by another request).
    """
    num_computed_tokens = req_status.num_locally_computed_tokens
    max_hit_size_tokens: int = req_status.req.num_tokens
    if self._sliding_window_groups:
        # the last prompt token has to be recomputed to get the logprobs
        # for sliding window attention, we must reduce by 1 to make sure
        # we still have a hit after reduction
        max_hit_size_tokens -= 1
    num_hit_tokens: int = 0
    defer_lookup = False
    lookup_groups = self._lookup_groups
    while lookup_groups:
        looked_up_sliding_window: bool = False
        groups_iter = iter(lookup_groups)
        # Cleared up-front; re-assigned below when another pass is needed.
        lookup_groups = ()
        for group_idx in groups_iter:
            group_config: GroupOffloadConfig = self.config.kv_group_configs[
                group_idx
            ]
            group_state: RequestGroupState = req_status.group_states[group_idx]
            offloaded_block_size = group_config.offloaded_block_size
            offload_keys = group_state.offload_keys

            assert (
                len(offload_keys)
                >= req_status.req.num_tokens // offloaded_block_size
            )

            # Constrain to block-aligned boundary for this group
            max_hit_size_tokens = min(
                max_hit_size_tokens, len(offload_keys) * offloaded_block_size
            )
            if max_hit_size_tokens - num_computed_tokens < offloaded_block_size:
                # we can only load less than a block, better skip
                return 0

            num_blocks = min(
                cdiv(max_hit_size_tokens, offloaded_block_size), len(offload_keys)
            )
            # Skip blocks already covered by locally computed tokens.
            start_block_idx = num_computed_tokens // offloaded_block_size
            offload_keys = offload_keys[start_block_idx:num_blocks]
            sliding_window_size_in_blocks = (
                group_config.sliding_window_size_in_blocks
            )

            # end index (in the sliced offload_keys) up to which we
            # have backend-confirmed hits
            num_hit_blocks: int | None
            if sliding_window_size_in_blocks is None:
                num_hit_blocks = self._maximal_prefix_lookup(
                    offload_keys, req_status.req_context
                )
            else:
                num_hit_blocks = self._sliding_window_lookup(
                    offload_keys,
                    sliding_window_size_in_blocks,
                    req_status.req_context,
                )
            if num_hit_blocks == 0:
                return 0

            if num_hit_blocks is None:
                # Backend deferred; remember it but keep scanning groups.
                defer_lookup = True
            else:
                max_hit_size_tokens = min(
                    max_hit_size_tokens,
                    offloaded_block_size * (start_block_idx + num_hit_blocks),
                )

            new_num_hit_tokens = max_hit_size_tokens - num_computed_tokens
            if new_num_hit_tokens < offloaded_block_size:
                # we can only load less than a block, better skip
                return 0

            if new_num_hit_tokens < num_hit_tokens:
                # This group shrank the hit below an earlier group's result.
                if defer_lookup:
                    # make another iteration on all groups to check
                    # if we still need to defer lookup
                    defer_lookup = False
                    lookup_groups = self._lookup_groups
                elif looked_up_sliding_window and not lookup_groups:
                    # we need another iteration to confirm previously looked up
                    # sliding window works with the new_num_hit_tokens
                    lookup_groups = self._sliding_window_groups
                    
            looked_up_sliding_window |= sliding_window_size_in_blocks is not None
            num_hit_tokens = new_num_hit_tokens

    if defer_lookup:
        logger.debug(
            "Offloading manager delayed request %s as backend requested",
            req_status.req.request_id,
        )
        return None

    # possibly delay request if any of the hit blocks is already being loaded
    if self._blocks_being_loaded:
        for group_config, group_state in zip(
            self.config.kv_group_configs, req_status.group_states
        ):
            offloaded_block_size = group_config.offloaded_block_size
            sliding_window_size_in_blocks = (
                group_config.sliding_window_size_in_blocks
            )
            offload_keys = group_state.offload_keys
            num_blocks = cdiv(
                num_computed_tokens + num_hit_tokens, offloaded_block_size
            )
            start_block_idx = num_computed_tokens // offloaded_block_size
            offload_keys = offload_keys[start_block_idx:num_blocks]
            if sliding_window_size_in_blocks is not None:
                # Only the trailing window of blocks is actually loaded.
                offload_keys = offload_keys[-sliding_window_size_in_blocks:]
            if any(key in self._blocks_being_loaded for key in offload_keys):
                # hit blocks are being loaded, delay request
                logger.debug(
                    "Delaying request %s since some of its"
                    " blocks are already being loaded",
                    req_status.req.request_id,
                )
                return None

    logger.debug(
        "Request %s hit %s offloaded tokens after %s GPU hit tokens",
        req_status.req.request_id,
        num_hit_tokens,
        num_computed_tokens,
    )

    return num_hit_tokens

_maximal_prefix_lookup

_maximal_prefix_lookup(
    keys: Iterable[OffloadKey], req_context: ReqContext
) -> int | None

Return the number of consecutive offloaded blocks from the start, or None if the backend deferred a lookup.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
def _maximal_prefix_lookup(
    self, keys: Iterable[OffloadKey], req_context: ReqContext
) -> int | None:
    """Return the number of consecutive offloaded blocks from the start,
    or None if the backend deferred a lookup."""
    hit_count = 0
    defer_lookup = False
    for key in keys:
        result = self.manager.lookup(key, req_context)
        if result is None:
            defer_lookup = True
            # continue lookup to allow manager to kick-off async lookups
            # for all blocks (until a miss is detected)
            result = True
        if not result:
            break
        hit_count += 1
    return hit_count if not defer_lookup else None

_sliding_window_lookup

_sliding_window_lookup(
    keys: Sequence[OffloadKey],
    sliding_window_size: int,
    req_context: ReqContext,
) -> int | None

Return the end index (in keys) of the last run of sliding_window_size consecutive hits, scanning from the end. If no full run is found, returns the length of the hit run anchored at index 0 (0 on a complete miss); returns None if the backend deferred a lookup.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
def _sliding_window_lookup(
    self,
    keys: Sequence[OffloadKey],
    sliding_window_size: int,
    req_context: ReqContext,
) -> int | None:
    """Return the end index (in `keys`) of the last run of
    `sliding_window_size` consecutive hits, scanning from the end.
    Returns 0 on miss, None if the backend deferred a lookup."""
    defer_lookup = False
    consecutive_hits = 0
    for idx in range(len(keys) - 1, -1, -1):
        result = self.manager.lookup(keys[idx], req_context)
        if result is None:
            defer_lookup = True
            # continue lookup to allow manager to kick-off async lookups
            # for all blocks (until a hit is detected)
            result = False
        if not result:
            consecutive_hits = 0
        else:
            consecutive_hits += 1
            if consecutive_hits == sliding_window_size:
                return idx + sliding_window_size if not defer_lookup else None
    return consecutive_hits if not defer_lookup else None

get_num_new_matched_tokens

get_num_new_matched_tokens(
    request: Request, num_computed_tokens: int
) -> tuple[int | None, bool]

Get number of new tokens that can be loaded beyond the num_computed_tokens.

Parameters:

Name Type Description Default
request Request

the request object.

required
num_computed_tokens int

the number of locally computed tokens for this request

required

Returns:

Type Description
tuple[int | None, bool]

A tuple with the following elements: - The number of tokens that can be loaded beyond what is already computed. If None, it means that the connector needs more time to determine the number of matched tokens, and the scheduler should query for this request again later. - True if tokens will be loaded asynchronously (between scheduler steps).

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
def get_num_new_matched_tokens(
    self, request: Request, num_computed_tokens: int
) -> tuple[int | None, bool]:
    """
    Get number of new tokens that can be loaded beyond the
    num_computed_tokens.

    Args:
        request (Request): the request object.
        num_computed_tokens (int): the number of locally
            computed tokens for this request

    Returns:
        A tuple with the following elements:
            - The number of tokens that can be loaded beyond what is
              already computed.
              If None, it means that the connector needs more time to
              determine the number of matched tokens, and the scheduler
              should query for this request again later.
            - `True` if tokens will be loaded asynchronously
              (between scheduler steps).
    """
    req_status = self._req_status.get(request.request_id)
    is_new_request = not req_status
    if is_new_request:
        req_status = RequestOffloadState(config=self.config, req=request)
        self._req_status[request.request_id] = req_status
    else:
        # Re-queried request: drop stale block IDs before re-evaluating.
        for group_state in req_status.group_states:
            group_state.block_ids.clear()

    req_status.update_offload_keys()
    req_status.num_locally_computed_tokens = num_computed_tokens

    num_hit_tokens = self._lookup(req_status)
    if is_new_request:
        req_status.update_num_hit_blocks(
            num_computed_tokens + (num_hit_tokens or 0)
        )

    self._touch(req_status)

    # Any positive hit implies an async load between scheduler steps.
    return num_hit_tokens, bool(num_hit_tokens)

request_finished

request_finished(
    request: Request,
) -> tuple[bool, dict[str, Any] | None]

Called when a request has finished, before its blocks are freed.

Returns:

Type Description
tuple[bool, dict[str, Any] | None]

A tuple with the following elements: - True if the request is being saved/sent asynchronously and blocks should not be freed until the request_id is returned from get_finished(). - Optional KVTransferParams to be included in the request outputs returned by the engine.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
def request_finished(
    self,
    request: Request,
) -> tuple[bool, dict[str, Any] | None]:
    """
    Called when a request has finished, before its blocks are freed.

    Returns:
        True if the request is being saved/sent asynchronously and blocks
        should not be freed until the request_id is returned from
        get_finished(); this connector always returns False.
        Optional KVTransferParams to be included in the request outputs
        returned by the engine; this connector always returns None.
    """
    # TODO(orozery): possibly kickoff offload for last block
    # which may have been deferred due to async scheduling
    req_status = self._req_status.get(request.request_id)
    if req_status is None:
        return False, None
    if not req_status.transfer_jobs:
        # No in-flight transfers: the request state can be dropped now.
        del self._req_status[request.request_id]
        return False, None
    # In-flight stores will outlive the request's block ownership;
    # register their source block IDs so that a future reuse of those
    # blocks triggers a flush of the pending jobs.
    pending_by_block = self._block_id_to_pending_jobs
    for job_id in req_status.transfer_jobs:
        block_ids = self._jobs[job_id].non_sliding_window_block_ids or ()
        for block_id in block_ids:
            pending_by_block.setdefault(block_id, set()).add(job_id)
    return False, None

take_events

take_events() -> Iterable[KVCacheEvent]

Take the KV cache events from the connector.

Returns:

Type Description
Iterable[KVCacheEvent]

A list of KV cache events.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
def take_events(self) -> Iterable[KVCacheEvent]:
    """Take the KV cache events from the connector.

    Yields:
        A `BlockRemoved` or `BlockStored` event per manager event.
    """
    for event in self.manager.take_events():
        hashes = [get_offload_block_hash(k) for k in event.keys]
        if event.removed:
            yield BlockRemoved(block_hashes=hashes, medium=event.medium)
        else:
            # Offloaded blocks carry no token/LoRA metadata here; only
            # the block hashes and the storage medium are populated.
            yield BlockStored(
                block_hashes=hashes,
                parent_block_hash=None,
                token_ids=[],
                lora_id=None,
                block_size=0,
                medium=event.medium,
                lora_name=None,
            )

update_connector_output

update_connector_output(
    connector_output: KVConnectorOutput,
)

Update KVConnector state from worker-side connectors output.

Parameters:

Name Type Description Default
connector_output KVConnectorOutput

the worker-side connectors output.

required
Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
def update_connector_output(self, connector_output: KVConnectorOutput):
    """
    Update KVConnector state from worker-side connectors output.

    Args:
        connector_output (KVConnectorOutput): the worker-side
            connectors output.
    """
    meta = connector_output.kv_connector_worker_meta
    if not isinstance(meta, OffloadingWorkerMetadata):
        # No worker metadata this step; fall back to an empty
        # (default-constructed) metadata object.
        assert meta is None
        meta = OffloadingWorkerMetadata()
    for job_id, count in meta.completed_jobs.items():
        assert count > 0
        job_status = self._jobs[job_id]
        # Each worker reports completions separately; the job is only
        # fully done once pending_count reaches exactly zero.
        job_status.pending_count -= count
        if job_status.pending_count > 0:
            continue
        assert job_status.pending_count == 0

        if job_status.is_store:
            self.manager.complete_store(job_status.keys)
        else:
            self.manager.complete_load(job_status.keys)
            if self._blocks_being_loaded:
                # These keys are no longer mid-load; unblock future hits.
                self._blocks_being_loaded.difference_update(job_status.keys)

        req_status = self._req_status[job_status.req_id]
        if self._block_id_to_pending_jobs:
            # Sliding window blocks are tracked from store creation
            # and must be cleaned up unconditionally.
            self._remove_pending_job(job_id, job_status.sliding_window_block_ids)
            # Non-sliding-window blocks are only tracked after
            # request_finished, so only clean up for finished requests.
            if req_status.req.is_finished():
                self._remove_pending_job(
                    job_id, job_status.non_sliding_window_block_ids
                )

        # Cleanup above must happen before the job record is dropped.
        del self._jobs[job_id]
        req_status.transfer_jobs.remove(job_id)
        # Finished request with no remaining transfers: drop its state.
        if not req_status.transfer_jobs and req_status.req.is_finished():
            del self._req_status[job_status.req_id]

TransferJobStatus dataclass

Tracks scheduler-side state for a single transfer job.

Source code in vllm/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py
@dataclass(slots=True)
class TransferJobStatus:
    """Tracks scheduler-side state for a single transfer job."""

    # ID of the request this transfer job belongs to.
    req_id: ReqId
    # Number of workers still pending. Starts at num_workers,
    # decremented as each worker reports completion. Job is done at 0.
    pending_count: int
    # Offload keys this job covers; passed to manager.complete_*().
    keys: set[OffloadKey]
    # True for a store job (completed via manager.complete_store),
    # False for a load job (completed via manager.complete_load).
    is_store: bool
    # Store src block IDs whose ref_cnt protects them while the request
    # runs. Only registered in _block_id_to_pending_jobs on request_finished.
    non_sliding_window_block_ids: list[int] | None = None
    # Store src block IDs that may be freed before the request finishes.
    # Registered in _block_id_to_pending_jobs at store creation time.
    sliding_window_block_ids: list[int] | None = None