Skip to content

vllm.model_executor.models.qwen2_5_vl

Inference-only Qwen2.5-VL model compatible with HuggingFace weights.

Qwen2_5_VLForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsEncoderCudaGraph, SupportsLoRA, SupportsPP, SupportsQuant, SupportsEagle, SupportsEagle3, SupportsMultiModalPruning, SupportsMRoPE

Source code in vllm/model_executor/models/qwen2_5_vl.py
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
@MULTIMODAL_REGISTRY.register_processor(
    Qwen2_5_VLMultiModalProcessor,
    info=Qwen2_5_VLProcessingInfo,
    dummy_inputs=Qwen2_5_VLDummyInputsBuilder,
)
class Qwen2_5_VLForConditionalGeneration(
    nn.Module,
    SupportsMultiModal,
    SupportsEncoderCudaGraph,
    SupportsLoRA,
    SupportsPP,
    SupportsQuant,
    SupportsEagle,
    SupportsEagle3,
    SupportsMultiModalPruning,
    SupportsMRoPE,
):
    # Fused vLLM parameter name -> checkpoint sub-module names that are
    # presumably stacked into it during weight loading.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
        "qkv": ["qkv"],  # For vision tower's already-packed QKV
    }

    # To ensure correct weight loading and mapping.
    # Rewrites HF checkpoint name prefixes onto this module's attribute layout
    # (`visual.*`, `language_model.model.*`, `language_model.lm_head.*`).
    # NOTE(review): the specific "model.language_model." / "model.visual."
    # prefixes are listed before the catch-all "model." prefix; presumably the
    # mapper applies the first matching prefix, so keep this order.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.language_model.": "language_model.model.",
            "model.visual.": "visual.",
            # mapping for original checkpoint
            "lm_head.": "language_model.lm_head.",
            "model.": "language_model.model.",
        }
    )

    # Declares that the vision encoder supports the "data" encoder-TP mode
    # (see `use_data_parallel` in `__init__`).
    supports_encoder_tp_data = True

    def iter_mm_grid_thw(
        self, mm_features: list[MultiModalFeatureSpec]
    ) -> Iterator[tuple[int, int, int, int, float]]:
        """
        Walk multimodal items in prompt order and describe each one's LLM grid.

        Items are visited in ascending ``mm_position.offset`` order. Spatial
        sizes are reported after integer division by the vision config's
        ``spatial_merge_size``.

        Args:
            mm_features: Multimodal feature specs for one request.

        Yields:
            ``(offset, grid_t, grid_h, grid_w, t_factor)`` per item. Images
            always report ``grid_t == 1`` and ``t_factor == 1.0``. Videos use
            ``second_per_grid_ts * tokens_per_second`` as ``t_factor``, where
            ``second_per_grid_ts`` falls back to 1.0 when missing/falsy and
            ``tokens_per_second`` falls back to 1.0 when not configured.

        Raises:
            ValueError: If an item's modality is neither "image" nor "video".
        """
        vision_cfg = self.config.vision_config
        merge = vision_cfg.spatial_merge_size
        tps = getattr(vision_cfg, "tokens_per_second", 1.0)

        for feat in sorted(mm_features, key=lambda item: item.mm_position.offset):
            pos = feat.mm_position.offset

            if feat.modality == "image":
                frames, height, width = feat.data["image_grid_thw"].data.tolist()
                assert frames == 1, f"Image must have 1 frame, got {frames}"
                yield pos, 1, height // merge, width // merge, 1.0
                continue

            if feat.modality != "video":
                raise ValueError(f"Unsupported modality: {feat.modality}")

            frames, height, width = feat.data["video_grid_thw"].data.tolist()
            seconds_per_grid = 1.0
            if feat.data.get("second_per_grid_ts", None):
                seconds_per_grid = feat.data["second_per_grid_ts"].data.item()
            yield (
                pos,
                frames,
                height // merge,
                width // merge,
                seconds_per_grid * tps,
            )

    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        mm_features: list[MultiModalFeatureSpec],
    ) -> tuple[torch.Tensor, int]:
        """
        Compute 3-row M-RoPE position ids for a whole prompt.

        Text spans get identical positions in all three rows, continuing from
        one past the largest position emitted so far. Each image/video item
        contributes a (t, h, w) index grid offset by the same starting value;
        when ``t_factor != 1.0`` the temporal row is scaled by ``t_factor``
        and truncated to an integer.

        Args:
            input_tokens: All token ids of the prompt.
            mm_features: Multimodal feature specs for the prompt's items.

        Returns:
            ``(positions, mrope_position_delta)``. ``positions`` has 3 rows;
            its column count equals ``len(input_tokens)`` assuming each item's
            placeholder span is exactly ``grid_t * grid_h * grid_w`` tokens
            (TODO confirm against the processor). ``mrope_position_delta`` is
            ``positions.max() + 1 - len(input_tokens)``.
        """
        llm_pos_ids_list: list = []
        # Index (into input_tokens) just past the previously processed span.
        st = 0

        for (
            offset,
            llm_grid_t,
            llm_grid_h,
            llm_grid_w,
            t_factor,
        ) in self.iter_mm_grid_thw(mm_features):
            # Text between the previous item and this one.
            text_len = offset - st
            # Positions continue from one past the largest position so far.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

            # (t, h, w) index grid; row 0 is temporal, rows 1/2 are spatial.
            grid_indices = np.indices((llm_grid_t, llm_grid_h, llm_grid_w))
            if t_factor != 1.0:
                # Scale the temporal index (videos) and truncate to int.
                grid_indices[0] = (grid_indices[0] * t_factor).astype(np.int64)
            llm_pos_ids_list.append(grid_indices.reshape(3, -1) + text_len + st_idx)
            st = offset + llm_grid_t * llm_grid_h * llm_grid_w

        # Trailing text after the last multimodal item.
        # NOTE(review): if there are no items and input_tokens is empty,
        # llm_pos_ids_list stays empty and np.concatenate below raises.
        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(
                np.broadcast_to(np.arange(text_len), (3, text_len)) + st_idx
            )

        llm_positions = np.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1)
        # Offset between the next M-RoPE position and the next token index.
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()

        return torch.from_numpy(llm_positions), mrope_position_delta

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        """Return the prompt placeholder text for an item of ``modality``.

        Any modality name beginning with "image" or "video" is accepted. The
        item index ``i`` is ignored because all items of a modality share the
        same placeholder.

        Raises:
            ValueError: For any other modality.
        """
        placeholders = {
            "image": "<|vision_start|><|image_pad|><|vision_end|>",
            "video": "<|vision_start|><|video_pad|><|vision_end|>",
        }
        for prefix, placeholder in placeholders.items():
            if modality.startswith(prefix):
                return placeholder
        raise ValueError("Only image or video modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """
        Build the Qwen2.5-VL vision tower and its Qwen2 language model.

        Args:
            vllm_config: Global vLLM config. ``model_config.hf_config`` is read
                as the Qwen2.5-VL HF config, and
                ``model_config.multimodal_config`` supplies the encoder TP mode
                and video-pruning (EVS) settings.
            prefix: Parameter-name prefix for this module's submodules.
        """
        super().__init__()
        config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
        multimodal_config = vllm_config.model_config.multimodal_config

        # In "data" mode, `_process_*_input` call
        # `run_dp_sharded_mrope_vision_model` instead of `self.visual(...)`.
        self.use_data_parallel = multimodal_config.mm_encoder_tp_mode == "data"
        self.config = config
        self.model_config = vllm_config.model_config
        self.vllm_config = vllm_config
        self.multimodal_config = multimodal_config
        self.video_pruning_rate = multimodal_config.video_pruning_rate
        # When enabled, embed_multimodal appends position channels (and prunes
        # video tokens), and encoder CUDA graphs are disabled.
        self.is_multimodal_pruning_enabled = (
            multimodal_config.is_multimodal_pruning_enabled()
        )

        # NOTE(review): `self.quant_config` is not assigned in this method;
        # presumably it is provided by the `SupportsQuant` mixin — confirm.
        with self._mark_tower_model(vllm_config, {"image", "video"}):
            self.visual = Qwen2_5_VisionTransformer(
                vision_config=config.vision_config,
                norm_eps=getattr(config, "rms_norm_eps", 1e-6),
                quant_config=self.quant_config,
                prefix=maybe_prefix(prefix, "visual"),
            )

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Qwen2ForCausalLM"],
            )

        # Pipeline-parallel hook delegated to the wrapped language model.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Qwen2_5_VLImageInputs | None:
        """Wrap raw image kwargs into a typed image-input dict.

        Raw ``pixel_values`` take priority over precomputed ``image_embeds``
        when both are supplied; ``image_grid_thw`` is forwarded unchanged.

        Returns:
            The pixel or embedding input, or ``None`` when neither
            ``pixel_values`` nor ``image_embeds`` is present.
        """
        pixels = kwargs.pop("pixel_values", None)
        embeds = kwargs.pop("image_embeds", None)
        grid_thw = kwargs.pop("image_grid_thw", None)

        if pixels is not None:
            return Qwen2_5_VLImagePixelInputs(
                type="pixel_values",
                pixel_values=pixels,
                image_grid_thw=grid_thw,
            )
        if embeds is not None:
            return Qwen2_5_VLImageEmbeddingInputs(
                type="image_embeds",
                image_embeds=embeds,
                image_grid_thw=grid_thw,
            )
        return None

    def _parse_and_validate_video_input(
        self, **kwargs: object
    ) -> Qwen2_5_VLVideoInputs | None:
        """Wrap raw video kwargs into a typed video-input dict.

        Raw ``pixel_values_videos`` take priority over precomputed
        ``video_embeds``. ``video_grid_thw`` and ``second_per_grid_ts`` are
        forwarded unchanged (either may be ``None``).

        Returns:
            The pixel or embedding input, or ``None`` when neither
            ``pixel_values_videos`` nor ``video_embeds`` is present.
        """
        pixels = kwargs.pop("pixel_values_videos", None)
        embeds = kwargs.pop("video_embeds", None)
        grid_thw = kwargs.pop("video_grid_thw", None)
        seconds = kwargs.pop("second_per_grid_ts", None)

        if pixels is not None:
            return Qwen2_5_VLVideoPixelInputs(
                type="pixel_values_videos",
                pixel_values_videos=pixels,
                video_grid_thw=grid_thw,
                second_per_grid_ts=seconds,
            )
        if embeds is not None:
            return Qwen2_5_VLVideoEmbeddingInputs(
                type="video_embeds",
                video_embeds=embeds,
                video_grid_thw=grid_thw,
                second_per_grid_ts=seconds,
            )
        return None

    def _process_image_input(
        self, image_input: Qwen2_5_VLImageInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode images (or reuse supplied embeddings) and split per item.

        In data-parallel encoder mode, the result of
        ``run_dp_sharded_mrope_vision_model`` is returned as-is. Otherwise the
        concatenated embeddings are split into one chunk per image of
        ``prod(grid_thw) // merge_size // merge_size`` rows.
        """
        grid_thw = image_input["image_grid_thw"]
        assert grid_thw.ndim == 2
        grid_rows = grid_thw.tolist()

        if image_input["type"] == "image_embeds":
            embeds = image_input["image_embeds"].type(self.visual.dtype)
        elif self.use_data_parallel:
            return run_dp_sharded_mrope_vision_model(
                self.visual,
                image_input["pixel_values"],
                grid_rows,
                rope_type="rope_3d",
            )
        else:
            embeds = self.visual(image_input["pixel_values"], grid_thw=grid_rows)

        merge = self.visual.spatial_merge_size
        rows_per_image = (grid_thw.prod(-1) // merge // merge).tolist()
        return embeds.split(rows_per_image)

    def _postprocess_image_embeds_evs(
        self,
        image_embeds_split: tuple[torch.Tensor, ...],
        image_input: Qwen2_5_VLImageInputs,
    ) -> tuple[torch.Tensor, ...]:
        """
        Append M-RoPE positions to every image embedding.

        Images are not pruned here, but carrying their positions alongside the
        embeddings lets the correct M-RoPE positions be recovered once video
        tokens elsewhere in the sequence have been pruned.

        Args:
            image_embeds_split: One embedding tensor per image item.
            image_input: Image input data (provides ``image_grid_thw``).

        Returns:
            Tuple with one tensor per image, each widened by the 4 position
            channels computed by ``compute_mrope_for_media``.
        """
        merge = self.visual.spatial_merge_size
        grids = image_input["image_grid_thw"].tolist()
        return tuple(
            torch.cat(
                [emb, compute_mrope_for_media(grid, merge).to(emb.device)],
                dim=1,
            )
            for emb, grid in zip(image_embeds_split, grids)
        )

    def _process_video_input(
        self, video_input: Qwen2_5_VLVideoInputs
    ) -> tuple[torch.Tensor, ...]:
        """Encode videos (or reuse supplied embeddings) and split per item.

        In data-parallel encoder mode, the result of
        ``run_dp_sharded_mrope_vision_model`` is returned as-is. Otherwise the
        concatenated embeddings are split into one chunk per video of
        ``prod(grid_thw) // merge_size // merge_size`` rows.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grid_rows = grid_thw.tolist()

        if video_input["type"] == "video_embeds":
            embeds = video_input["video_embeds"].type(self.visual.dtype)
        elif self.use_data_parallel:
            return run_dp_sharded_mrope_vision_model(
                self.visual,
                video_input["pixel_values_videos"],
                grid_rows,
                rope_type="rope_3d",
            )
        else:
            embeds = self.visual(
                video_input["pixel_values_videos"], grid_thw=grid_rows
            )

        merge = self.visual.spatial_merge_size
        rows_per_video = (grid_thw.prod(-1) // merge // merge).tolist()
        return embeds.split(rows_per_video)

    def _postprocess_video_embeds_evs(
        self,
        video_embeds_split: tuple[torch.Tensor, ...],
        video_input: Qwen2_5_VLVideoInputs,
    ) -> tuple[torch.Tensor, ...]:
        """
        Prune video embeddings with Efficient Video Sampling (EVS) and append
        the M-RoPE positions of the retained tokens.

        Args:
            video_embeds_split: One embedding tensor per video item.
            video_input: Video input data (provides ``video_grid_thw`` and
                ``second_per_grid_ts``).

        Returns:
            Tuple with one tensor per video holding only the retained rows,
            each widened by the 4 position channels computed by
            ``compute_mrope_for_media``.

        Raises:
            ValueError: If ``second_per_grid_ts`` is missing.
        """
        grid_thw = video_input["video_grid_thw"]
        assert grid_thw.ndim == 2
        grids = grid_thw.tolist()
        merge = self.visual.spatial_merge_size

        # Cast to long to match the original code
        # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa
        seconds = video_input.get("second_per_grid_ts")
        if seconds is None:
            raise ValueError(
                "second_per_grid_ts is required when video_pruning_rate > 0 "
                "is enabled for video inputs, including the video_embeds path."
            )
        seconds = seconds.long()
        tps = self.config.vision_config.tokens_per_second

        pruned: list[torch.Tensor] = []
        for emb, grid, secs in zip(video_embeds_split, grids, seconds):
            # Boolean mask of the tokens EVS keeps for this video.
            keep = compute_retention_mask(
                emb,
                grid,
                spatial_merge_size=merge,
                q=self.video_pruning_rate,
            )
            positions = compute_mrope_for_media(
                grid,
                merge,
                tokens_per_second=tps,
                video_second_per_grid=secs.item(),
            ).to(emb.device)
            pruned.append(torch.cat([emb[keep], positions[keep]], dim=1))
        return tuple(pruned)

    def recompute_mrope_positions(
        self,
        input_ids: list[int],
        multimodal_embeddings: tuple[torch.Tensor, ...],
        mrope_positions: torch.LongTensor,
        num_computed_tokens: int,
    ) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]:
        """
        Refresh M-RoPE positions after media tokens have been pruned.

        Positions precomputed for the unpruned prompt become wrong once EVS
        drops media tokens. This strips the 4 trailing position channels from
        each multimodal embedding and hands them, together with the prompt
        tokens, to the module-level ``recompute_mrope_positions`` helper, which
        updates positions from ``num_computed_tokens`` onward.

        Args:
            input_ids: (N,) token ids of the entire prompt.
            multimodal_embeddings: Embeddings whose last 4 columns carry
                per-token positions.
            mrope_positions: Existing (3, N) positions for the whole sequence.
            num_computed_tokens: Number of tokens already computed.

        Returns:
            ``(embeddings_without_positions, mrope_positions,
            mrope_position_delta)``.
        """
        cfg = self.config

        if len(multimodal_embeddings):
            device = multimodal_embeddings[0].device
        else:
            device = mrope_positions.device

        token_ids = torch.as_tensor(input_ids, device=device, dtype=torch.long)

        stripped: list[torch.Tensor] = []
        media_positions: list[torch.Tensor] = []
        for mm in multimodal_embeddings:
            stripped.append(mm[:, :-4])
            media_positions.append(mm[:, -4:].permute(1, 0).long())

        # Resolves to the module-level helper, not this method.
        positions, delta = recompute_mrope_positions(
            token_ids,
            media_positions,
            mrope_positions,
            num_computed_tokens,
            cfg.vision_start_token_id,
            cfg.image_token_id,
            cfg.video_token_id,
        )
        return tuple(stripped), positions, delta

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        """Parse image/video kwargs into typed inputs keyed by modality.

        A modality appears in the result in the order its first triggering key
        appears in ``kwargs``, so downstream consumers can keep the prompt's
        modality order.
        """
        image_keys = ("pixel_values", "image_embeds")
        video_keys = ("pixel_values_videos", "video_embeds")
        parsed = {}

        for key in kwargs:
            if key in image_keys and "image" not in parsed:
                parsed["image"] = self._parse_and_validate_image_input(**kwargs)
            if key in video_keys and "video" not in parsed:
                parsed["video"] = self._parse_and_validate_video_input(**kwargs)
        return parsed

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        """Encode all image/video items in ``kwargs``.

        Returns an empty list when no multimodal input is present. Otherwise it
        returns a tuple with one tensor per item (image or video), ordered by
        modality in the order produced by
        ``_parse_and_validate_multimodal_inputs``. When multimodal pruning is
        enabled, each tensor also carries appended position channels (and
        video tokens are pruned).
        """
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        if not mm_input_by_modality:
            return []

        collected: list[torch.Tensor] = []
        # Iterating the dict preserves the modality order of the prompt.
        for modality, mm_input in mm_input_by_modality.items():
            if modality == "image":
                items = self._process_image_input(mm_input)
                if self.is_multimodal_pruning_enabled:
                    items = self._postprocess_image_embeds_evs(items, mm_input)
                collected.extend(items)
            if modality == "video":
                items = self._process_video_input(mm_input)
                if self.is_multimodal_pruning_enabled:
                    items = self._postprocess_video_embeds_evs(items, mm_input)
                collected.extend(items)
        return tuple(collected)

    # -- SupportsEncoderCudaGraph protocol methods --

    def get_encoder_cudagraph_config(self):
        """Describe how the vision encoder is captured into CUDA graphs.

        Returns an ``EncoderCudaGraphConfig`` naming the captured modalities,
        the pixel-input key per modality, the metadata buffer names and the
        encoder output width.
        """
        from vllm.v1.worker.encoder_cudagraph_defs import (
            EncoderCudaGraphConfig,
        )

        # With EVS pruning, embed_multimodal() post-processes embeddings
        # (positions appended for images; pruning + positions for videos), but
        # the CUDA graph path skips that hook. Capture nothing in that case so
        # eager and graph paths never produce differently shaped embeddings.
        if self.is_multimodal_pruning_enabled:
            captured = []
        else:
            captured = ["image", "video"]

        metadata_keys = [
            "rotary_pos_emb_cos",
            "rotary_pos_emb_sin",
            "window_index",
            "reverse_indices",
            "cu_seqlens",
            "cu_window_seqlens",
            "max_seqlen_full",
            "max_seqlen_window",
        ]
        pixel_keys = {
            "image": "pixel_values",
            "video": "pixel_values_videos",
        }

        return EncoderCudaGraphConfig(
            modalities=captured,
            input_key_by_modality=pixel_keys,
            buffer_keys=metadata_keys,
            out_hidden_size=self.visual.out_hidden_size,
        )

    def get_input_modality(
        self,
        mm_kwargs: dict[str, Any],
    ) -> str:
        """Return "image" if ``mm_kwargs`` has ``image_grid_thw``, else "video"."""
        return "image" if "image_grid_thw" in mm_kwargs else "video"

    def get_max_frames_per_video(self) -> int:
        """Ask the processing info how many frames one video may contribute.

        The query uses the model's ``max_model_len`` as the sequence length and
        the configured per-prompt video limit as the video count.
        """
        info = MULTIMODAL_REGISTRY.get_processing_info(self.model_config)
        return info.get_num_frames_with_most_features(
            seq_len=self.model_config.max_model_len,
            mm_counts={
                "video": self.multimodal_config.get_limit_per_prompt("video")
            },
        )

    def get_encoder_cudagraph_budget_range(
        self,
        vllm_config: VllmConfig,
    ) -> tuple[int, int]:
        """Return the ``(min, max)`` encoder token budgets for CUDA graphs."""
        # Smallest expected input: a 224x224 image is 16x16 patches at
        # patch_size=14, i.e. 8x8 = 64 tokens after a 2x2 spatial merge.
        smallest = 64
        # Largest: one scheduler step, never beyond the model's context length.
        largest = min(
            vllm_config.scheduler_config.max_num_batched_tokens,
            self.model_config.max_model_len,
        )
        return smallest, largest

    def _get_pixel_values_by_modality(
        self,
        mm_kwargs: dict[str, Any],
    ) -> torch.Tensor:
        """Fetch the pixel tensor matching the batch's modality."""
        if self.get_input_modality(mm_kwargs) == "image":
            key = "pixel_values"
        else:
            key = "pixel_values_videos"
        return mm_kwargs[key]

    def _get_grid_thw_by_modality(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[tuple[int, int, int]]:
        """Fetch ``<modality>_grid_thw`` as a Python list (tensors are converted)."""
        modality = self.get_input_modality(mm_kwargs)
        grid_thw = mm_kwargs[f"{modality}_grid_thw"]
        return grid_thw if isinstance(grid_thw, list) else grid_thw.tolist()

    def get_encoder_cudagraph_num_items(
        self,
        mm_kwargs: dict[str, Any],
    ) -> int:
        """Return the number of multimodal items (rows of ``grid_thw``)."""
        grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
        return len(grid_thw)

    def get_encoder_cudagraph_per_item_output_tokens(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[int]:
        """Return each item's encoder output token count after spatial merge."""
        merge = self.visual.spatial_merge_size
        return [
            frames * (height // merge) * (width // merge)
            for frames, height, width in self._get_grid_thw_by_modality(mm_kwargs)
        ]

    def get_encoder_cudagraph_per_item_input_sizes(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[int]:
        """Return each item's input patch count (``t * h * w``)."""
        return [
            frames * height * width
            for frames, height, width in self._get_grid_thw_by_modality(mm_kwargs)
        ]

    def select_encoder_cudagraph_items(
        self,
        mm_kwargs: dict[str, Any],
        indices: list[int],
    ) -> dict[str, Any]:
        """Build mm_kwargs containing only the items at ``indices``.

        Pixel rows are located via each item's ``t * h * w`` patch count. The
        returned dict uses the same keys as the input modality (image or
        video). An empty ``indices`` yields an empty pixel slice and an empty
        grid list.
        """
        grid_thw = self._get_grid_thw_by_modality(mm_kwargs)
        pixel_values = self._get_pixel_values_by_modality(mm_kwargs)

        if self.get_input_modality(mm_kwargs) == "image":
            pixel_key, grid_key = "pixel_values", "image_grid_thw"
        else:
            pixel_key, grid_key = "pixel_values_videos", "video_grid_thw"

        if len(indices) == 0:
            return {pixel_key: pixel_values[:0], grid_key: []}

        # bounds[i] .. bounds[i + 1] is item i's row range in pixel_values.
        bounds = [0]
        for frames, height, width in grid_thw:
            bounds.append(bounds[-1] + frames * height * width)

        chunks = [pixel_values[bounds[i] : bounds[i + 1]] for i in indices]
        return {
            pixel_key: torch.cat(chunks),
            grid_key: [grid_thw[i] for i in indices],
        }

    def prepare_encoder_cudagraph_capture_inputs(
        self,
        token_budget: int,
        max_batch_size: int,
        max_frames_per_batch: int,
        device: torch.device,
        dtype: torch.dtype,
    ):
        """
        Build dummy inputs and metadata buffers for capturing an encoder graph.

        The dummy batch has ``max_batch_size`` items whose combined merged-token
        count is at least ``token_budget``. When ``max_frames_per_batch`` allows
        more than one frame per item, each item uses a multi-frame grid so that
        per-frame buffers such as ``cu_seqlens`` are sized for video replays.
        Otherwise each item uses a single-frame grid.

        Args:
            token_budget: Target number of merged output tokens for the graph.
            max_batch_size: Number of dummy items in the captured batch.
            max_frames_per_batch: Upper bound on frames across the batch.
            device: Device for the dummy tensors and metadata buffers.
            dtype: Dtype of the dummy ``pixel_values``.

        Returns:
            ``EncoderCudaGraphCaptureInputs`` whose ``mm_kwargs`` use the image
            keys (``pixel_values`` / ``image_grid_thw``) plus the buffers
            returned by ``self.visual.prepare_encoder_metadata``.
        """
        from vllm.v1.worker.encoder_cudagraph_defs import (
            EncoderCudaGraphCaptureInputs,
        )

        spatial_merge_size = self.visual.spatial_merge_size
        # Upper bound on window sequences; matches the value used at replay.
        max_window_seqs_per_batch = min(
            self.vllm_config.scheduler_config.max_num_batched_tokens,
            self.model_config.max_model_len,
        )
        # Use ceil here (not floor) so total captured capacity is never smaller
        # than token_budget when token_budget is not divisible by max_batch_size
        # (e.g., 324 budget with max_batch_size=8). Floor under-allocates
        # input_buffer and can fail replay copy for valid single-item batches.
        per_mm_item_output = (token_budget + max_batch_size - 1) // max_batch_size

        frames_per_item = max_frames_per_batch // max_batch_size
        if frames_per_item > 1:
            # Build the capture grid using a video-format layout so that
            # cu_seqlens is sized for video replays from the start.
            # cu_seqlens has one entry per attention sequence (one per frame),
            # so using T > 1 per item makes the buffer large enough without
            # relying solely on padding.
            # Ceiling ensures frames_per_item * tokens_per_frame >= per_mm_item_output
            # so the pixel_values buffer covers any valid single-item replay.
            tokens_per_frame = (
                per_mm_item_output + frames_per_item - 1
            ) // frames_per_item
            # Video-format grid_config (T=frames_per_item).
            # H = merge size and W = tokens_per_frame * merge size give exactly
            # tokens_per_frame merged tokens per frame.
            grid_config = [
                [
                    frames_per_item,
                    spatial_merge_size,
                    tokens_per_frame * spatial_merge_size,
                ]
                for _ in range(max_batch_size)
            ]
        else:
            # Image-format grid_config (T=1).
            # Each item yields exactly per_mm_item_output merged tokens.
            grid_config = [
                [1, spatial_merge_size, per_mm_item_output * spatial_merge_size]
                for _ in range(max_batch_size)
            ]

        # Create dummy pixel_values
        # One row per patch; row width is C * temporal_patch * patch * patch.
        patch_embed = self.visual.patch_embed
        in_channels = patch_embed.proj.in_channels
        patch_size = patch_embed.patch_size
        temporal_patch_size = patch_embed.temporal_patch_size
        total_patches = sum(t * h * w for t, h, w in grid_config)
        flattened_patch_size = (
            in_channels * temporal_patch_size * patch_size * patch_size
        )
        dummy_pixel_values = torch.randn(
            total_patches, flattened_patch_size, device=device, dtype=dtype
        )

        # Override max_seqlen with a safe upper bound for capture.
        # max_seqlen.item() gets baked into the CUDA graph (not replayed),
        # so the capture value must cover any replay scenario.
        # Worst case: 1 item consuming the full budget ->
        # seq_len = token_budget * spatial_merge_size^2.
        # For window-attention, each local window is bounded by fixed geometry:
        # (window_size / patch_size / spatial_merge_size)^2 windows in merged
        # token space, multiplied by spatial_merge_size^2 to map back to the
        # unmerged sequence length used by attention kernels.
        vit_merger_window_size = (
            self.visual.window_size
            // self.visual.spatial_merge_size
            // self.visual.patch_size
        )
        max_seqlen_window_override = vit_merger_window_size**2 * (spatial_merge_size**2)
        buffers = self.visual.prepare_encoder_metadata(
            grid_config,
            max_batch_size=max_batch_size,
            max_frames_per_batch=max_frames_per_batch,
            max_window_seqs_per_batch=max_window_seqs_per_batch,
            max_seqlen_override=token_budget * (spatial_merge_size**2),
            max_seqlen_window_override=max_seqlen_window_override,
            device=device,
        )

        # Just use image-modality dummy input_buffer for capturing, since it's also
        # compatible for video inputs (has the same shape: [num_patches, C*T*P*P]).
        mm_kwargs = {
            "pixel_values": dummy_pixel_values,
            "image_grid_thw": grid_config,
        }

        return EncoderCudaGraphCaptureInputs(
            mm_kwargs=mm_kwargs,
            buffers=buffers,
        )

    def prepare_encoder_cudagraph_replay_buffers(
        self,
        mm_kwargs: dict[str, Any],
        max_batch_size: int,
        max_frames_per_batch: int,
    ):
        """Build encoder metadata buffers for replaying a captured CUDA graph.

        Image inputs are bounded by ``max_batch_size``. Any other modality
        (video) is bounded by ``max_frames_per_batch``. In both cases the
        number of window-attention sequences is capped at
        ``min(max_num_batched_tokens, max_model_len)``.

        Args:
            mm_kwargs: Multimodal keyword inputs for a single modality.
            max_batch_size: Capacity bound used for image inputs.
            max_frames_per_batch: Capacity bound used for non-image inputs.

        Returns:
            ``EncoderCudaGraphReplayBuffers`` wrapping the metadata produced
            by ``self.visual.prepare_encoder_metadata``.
        """
        modality = self.get_input_modality(mm_kwargs)
        grid_thw_list = self._get_grid_thw_by_modality(mm_kwargs)

        window_seq_cap = min(
            self.vllm_config.scheduler_config.max_num_batched_tokens,
            self.model_config.max_model_len,
        )
        # Images are limited by item count, videos by total frame count.
        if modality == "image":
            capacity_kwargs = {"max_batch_size": max_batch_size}
        else:
            capacity_kwargs = {"max_frames_per_batch": max_frames_per_batch}

        metadata = self.visual.prepare_encoder_metadata(
            grid_thw_list,
            max_window_seqs_per_batch=window_seq_cap,
            **capacity_kwargs,
        )
        return EncoderCudaGraphReplayBuffers(buffers=metadata)

    def encoder_cudagraph_forward(
        self,
        mm_kwargs: dict[str, Any],
        buffers: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Run the vision tower using precomputed encoder metadata buffers.

        Args:
            mm_kwargs: Multimodal keyword inputs for a single modality.
            buffers: Encoder metadata, forwarded as ``encoder_metadata``.

        Returns:
            The output of ``self.visual``.
        """
        return self.visual(
            self._get_pixel_values_by_modality(mm_kwargs),
            self._get_grid_thw_by_modality(mm_kwargs),
            encoder_metadata=buffers,
        )

    def encoder_eager_forward(
        self,
        mm_kwargs: dict[str, Any],
    ) -> torch.Tensor:
        """Run the vision tower eagerly, without encoder metadata buffers.

        Args:
            mm_kwargs: Multimodal keyword inputs for a single modality.

        Returns:
            The output of ``self.visual``.
        """
        return self.visual(
            self._get_pixel_values_by_modality(mm_kwargs),
            self._get_grid_thw_by_modality(mm_kwargs),
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        """Run forward pass for Qwen2.5-VL.

        Args:
            input_ids: Flattened (concatenated) input_ids corresponding to a
                batch.
            positions: Flattened (concatenated) position ids corresponding to a
                batch. **NOTE**: If mrope is enabled (default setting for
                Qwen2.5-VL opensource models), the shape will be `(3, seq_len)`,
                otherwise it will be `(seq_len,)`.
            intermediate_tensors: Intermediate tensors (presumably from a
                previous pipeline-parallel stage). When provided,
                `inputs_embeds` is ignored.
            inputs_embeds: Optional precomputed input embeddings.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The result of `self.language_model.model(...)`.
        """

        # Intermediate tensors take precedence over caller-provided embeddings.
        if intermediate_tensors is not None:
            inputs_embeds = None

        hidden_states = self.language_model.model(
            input_ids=input_ids,
            positions=positions,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Delegate logit computation to the wrapped language model."""
        lm = self.language_model
        return lm.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, renaming HF keys via ``hf_to_vllm_mapper``.

        Returns:
            The set of names reported as loaded by ``AutoWeightsLoader``.
        """
        return AutoWeightsLoader(self).load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Return the module-name prefixes of the language model, the
        vision-to-language connector, and the vision tower.
        """
        prefixes = dict(
            language_model="language_model",
            connector="visual.merger.",
            tower_model="visual.",
        )
        return MultiModelKeys.from_string_field(**prefixes)

    def get_num_mm_encoder_tokens(
        self,
        num_image_tokens: int,
    ) -> int:
        """Convert merged image-token count to vision-encoder token count.

        Each merged token covers ``spatial_merge_size ** 2`` encoder tokens.
        """
        merge_size = self.config.vision_config.spatial_merge_size
        tokens_per_merged = merge_size * merge_size
        return num_image_tokens * tokens_per_merged

    def get_num_mm_connector_tokens(
        self,
        num_vision_tokens: int,
    ) -> int:
        """Convert vision-encoder token count to merged (connector) token count.

        Uses floor division by ``spatial_merge_size ** 2``.
        """
        merge_size = self.config.vision_config.spatial_merge_size
        return num_vision_tokens // (merge_size * merge_size)

_postprocess_image_embeds_evs

_postprocess_image_embeds_evs(
    image_embeds_split: tuple[Tensor, ...],
    image_input: Qwen2_5_VLImageInputs,
) -> tuple[Tensor, ...]

Append mrope positions to the embeddings of each image. This is necessary to recover correct mrope positions after video pruning.

Parameters:

Name Type Description Default
image_embeds_split tuple[Tensor, ...]

Tuple of image embeddings for each image item.

required
image_input Qwen2_5_VLImageInputs

Image input data.

required

Returns:

Type Description
tuple[Tensor, ...]

Tuple of image embeddings for each image item. Resulting embeddings will have extra 4 channels for computed mrope positions.

Source code in vllm/model_executor/models/qwen2_5_vl.py
def _postprocess_image_embeds_evs(
    self,
    image_embeds_split: tuple[torch.Tensor, ...],
    image_input: Qwen2_5_VLImageInputs,
) -> tuple[torch.Tensor, ...]:
    """
    Append mrope positions to the embeddings of each image.

    This is necessary to recover correct mrope positions after video
    pruning.

    Args:
        image_embeds_split: Tuple of image embeddings, one per image item.
        image_input: Image input data; its ``image_grid_thw`` entry supplies
            the per-image grid used to compute positions.

    Returns:
        Tuple of image embeddings, one per image item. Each embedding has the
        computed mrope positions concatenated along dim 1 (4 extra channels).
    """
    merge_size = self.visual.spatial_merge_size
    grid_thw_list = image_input["image_grid_thw"].tolist()
    image_embeds_out = []
    for emb, size in zip(image_embeds_split, grid_thw_list):
        # Move positions onto the embedding's device so torch.cat succeeds.
        positions = compute_mrope_for_media(size, merge_size).to(emb.device)
        image_embeds_out.append(torch.cat([emb, positions], dim=1))
    return tuple(image_embeds_out)

_postprocess_video_embeds_evs

_postprocess_video_embeds_evs(
    video_embeds_split: tuple[Tensor, ...],
    video_input: Qwen2_5_VLVideoInputs,
) -> tuple[Tensor, ...]

Prunes video embeddings via Efficient Video Sampling (EVS) and then appends mrope positions for each retained embedding.

Parameters:

Name Type Description Default
video_embeds_split tuple[Tensor, ...]

Tuple of video embeddings for each video item.

required
video_input Qwen2_5_VLVideoInputs

Video input data.

required

Returns:

Type Description
tuple[Tensor, ...]

Tuple of video embeddings for each video item. Resulting embeddings will have extra 4 channels for computed mrope positions.

Source code in vllm/model_executor/models/qwen2_5_vl.py
def _postprocess_video_embeds_evs(
    self,
    video_embeds_split: tuple[torch.Tensor, ...],
    video_input: Qwen2_5_VLVideoInputs,
) -> tuple[torch.Tensor, ...]:
    """
    Prune each video's embeddings with Efficient Video Sampling (EVS) and
    append mrope positions to every retained embedding.

    Args:
        video_embeds_split: Tuple of video embeddings, one per video item.
        video_input: Video input data. Must contain ``video_grid_thw`` (2-D)
            and a non-None ``second_per_grid_ts``.

    Returns:
        Tuple of pruned video embeddings, one per video item. Each has the
        computed mrope positions concatenated along dim 1 (4 extra channels).

    Raises:
        ValueError: If ``second_per_grid_ts`` is missing.
    """
    video_grid_thw = video_input["video_grid_thw"]
    assert video_grid_thw.ndim == 2
    sizes = video_grid_thw.tolist()
    merge_size = self.visual.spatial_merge_size

    per_grid_seconds = video_input.get("second_per_grid_ts")
    if per_grid_seconds is None:
        raise ValueError(
            "second_per_grid_ts is required when video_pruning_rate > 0 "
            "is enabled for video inputs, including the video_embeds path."
        )
    # Cast to long to match the original code
    # https://github.com/huggingface/transformers/blob/41980ce93e775f6c88500c51c8db7946fc6a2add/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py#L491 # noqa
    # NOTE(review): .long() truncates fractional seconds (e.g. 0.5 -> 0)
    # before they reach compute_mrope_for_media; confirm this is intended.
    per_grid_seconds = per_grid_seconds.long()
    tokens_per_second = self.config.vision_config.tokens_per_second

    pruned = []
    for emb, size, seconds in zip(video_embeds_split, sizes, per_grid_seconds):
        # The EVS retention mask is computed before pruning; the same mask
        # is then applied to both the embeddings and their positions.
        keep = compute_retention_mask(
            emb,
            size,
            spatial_merge_size=merge_size,
            q=self.video_pruning_rate,
        )
        positions = compute_mrope_for_media(
            size,
            merge_size,
            tokens_per_second=tokens_per_second,
            video_second_per_grid=seconds.item(),
        ).to(emb.device)
        pruned.append(torch.cat([emb[keep], positions[keep]], dim=1))
    return tuple(pruned)

forward

forward(
    input_ids: Tensor | None,
    positions: Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: Tensor | None = None,
    **kwargs: object,
) -> Tensor | IntermediateTensors

Run forward pass for Qwen2.5-VL.

Parameters:

Name Type Description Default
input_ids Tensor | None

Flattened (concatenated) input_ids corresponding to a batch.

required
positions Tensor

Flattened (concatenated) position ids corresponding to a batch. NOTE: If mrope is enabled (default setting for Qwen2.5-VL opensource models), the shape will be `(3, seq_len)`, otherwise it will be `(seq_len,)`.

required
Source code in vllm/model_executor/models/qwen2_5_vl.py
def forward(
    self,
    input_ids: torch.Tensor | None,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors | None = None,
    inputs_embeds: torch.Tensor | None = None,
    **kwargs: object,
) -> torch.Tensor | IntermediateTensors:
    """Run forward pass for Qwen2.5-VL.

    Args:
        input_ids: Flattened (concatenated) input_ids corresponding to a
            batch.
        positions: Flattened (concatenated) position ids corresponding to a
            batch. **NOTE**: If mrope is enabled (default setting for
            Qwen2.5-VL opensource models), the shape will be `(3, seq_len)`,
            otherwise it will be `(seq_len,)`.
        intermediate_tensors: Intermediate tensors (presumably from a
            previous pipeline-parallel stage). When provided,
            `inputs_embeds` is ignored.
        inputs_embeds: Optional precomputed input embeddings.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        The result of `self.language_model.model(...)`.
    """

    # Intermediate tensors take precedence over caller-provided embeddings.
    if intermediate_tensors is not None:
        inputs_embeds = None

    hidden_states = self.language_model.model(
        input_ids=input_ids,
        positions=positions,
        intermediate_tensors=intermediate_tensors,
        inputs_embeds=inputs_embeds,
    )
    return hidden_states

get_mm_mapping

get_mm_mapping() -> MultiModelKeys

Get the module prefix in multimodal models

Source code in vllm/model_executor/models/qwen2_5_vl.py
def get_mm_mapping(self) -> MultiModelKeys:
    """
    Return the module-name prefixes of the language model, the
    vision-to-language connector, and the vision tower.
    """
    prefixes = dict(
        language_model="language_model",
        connector="visual.merger.",
        tower_model="visual.",
    )
    return MultiModelKeys.from_string_field(**prefixes)

iter_mm_grid_thw

iter_mm_grid_thw(
    mm_features: list[MultiModalFeatureSpec],
) -> Iterator[tuple[int, int, int, int, float]]

Iterate over multimodal features and yield grid information.

Parameters:

Name Type Description Default
mm_features list[MultiModalFeatureSpec]

List of multimodal feature specifications

required

Yields:

Type Description
tuple[int, int, int, int, float]

Tuple of (offset, grid_t, grid_h, grid_w, t_factor) for each image or video item

Source code in vllm/model_executor/models/qwen2_5_vl.py
def iter_mm_grid_thw(
    self, mm_features: list[MultiModalFeatureSpec]
) -> Iterator[tuple[int, int, int, int, float]]:
    """
    Iterate over multimodal features and yield grid information.

    Features are visited in ascending order of ``mm_position.offset``.
    Spatial grid sizes are divided by the vision ``spatial_merge_size``.

    Args:
        mm_features: List of multimodal feature specifications

    Yields:
        Tuple of (offset, grid_t, grid_h, grid_w, t_factor) for each image or
        video item. Images always report ``grid_t == 1`` and
        ``t_factor == 1.0``. Videos report
        ``t_factor = second_per_grid_ts * tokens_per_second``, where
        ``second_per_grid_ts`` defaults to 1.0 when the field is absent.

    Raises:
        ValueError: If a feature's modality is neither "image" nor "video".
    """
    vision_config = self.config.vision_config
    spatial_merge_size = vision_config.spatial_merge_size
    tokens_per_second = getattr(vision_config, "tokens_per_second", 1.0)
    for mm_feature in sorted(mm_features, key=lambda f: f.mm_position.offset):
        offset = mm_feature.mm_position.offset
        if mm_feature.modality == "image":
            t, h, w = mm_feature.data["image_grid_thw"].data.tolist()
            assert t == 1, f"Image must have 1 frame, got {t}"
            yield offset, 1, h // spatial_merge_size, w // spatial_merge_size, 1.0
        elif mm_feature.modality == "video":
            t, h, w = mm_feature.data["video_grid_thw"].data.tolist()
            # Test presence explicitly: a truthiness check would silently
            # fall back to 1.0 for a present-but-falsy field.
            second_per_grid_ts = 1.0
            second_field = mm_feature.data.get("second_per_grid_ts", None)
            if second_field is not None:
                second_per_grid_ts = second_field.data.item()
            t_factor = second_per_grid_ts * tokens_per_second
            yield (
                offset,
                t,
                h // spatial_merge_size,
                w // spatial_merge_size,
                t_factor,
            )
        else:
            raise ValueError(f"Unsupported modality: {mm_feature.modality}")

recompute_mrope_positions

recompute_mrope_positions(
    input_ids: list[int],
    multimodal_embeddings: tuple[Tensor, ...],
    mrope_positions: LongTensor,
    num_computed_tokens: int,
) -> tuple[tuple[Tensor, ...], Tensor, int]

Update part of the input mrope positions (starting at index num_computed_tokens). The original mrope_positions are computed for the unpruned sequence and become incorrect once pruning occurs, so after pruning media tokens we must reflect this in mrope_positions before feeding them to the LLM.

Parameters:

Name Type Description Default
input_ids list[int]

(N,) All input tokens of the prompt (Containing entire sequence).

required
multimodal_embeddings tuple[Tensor, ...]

Tuple of multimodal embeddings.

required
mrope_positions LongTensor

Existing mrope positions (3, N) for entire sequence

required
num_computed_tokens int

A number of computed tokens so far.

required

Returns:

Type Description
tuple[tuple[Tensor, ...], Tensor, int]

Tuple of (multimodal_embeddings, mrope_positions, mrope_position_delta).

Source code in vllm/model_executor/models/qwen2_5_vl.py
def recompute_mrope_positions(
    self,
    input_ids: list[int],
    multimodal_embeddings: tuple[torch.Tensor, ...],
    mrope_positions: torch.LongTensor,
    num_computed_tokens: int,
) -> tuple[tuple[torch.Tensor, ...], torch.Tensor, int]:
    """
    Rewrite the mrope positions from index ``num_computed_tokens`` onward.

    The incoming ``mrope_positions`` were computed for the unpruned
    sequence and become incorrect once media tokens are pruned. Each
    multimodal embedding carries its own positions in its last 4 columns.
    This method splits those columns off and hands them to the module-level
    ``recompute_mrope_positions`` helper.

    Args:
        input_ids: (N,) All input tokens of the prompt (entire sequence).
        multimodal_embeddings: Tuple of multimodal embeddings, each with 4
            trailing position channels.
        mrope_positions: Existing mrope positions (3, N) for the entire
            sequence.
        num_computed_tokens: Number of tokens already computed.

    Returns:
        Tuple of (multimodal_embeddings without the 4 position channels,
        mrope_positions, mrope_position_delta).
    """
    cfg = self.config
    image_token_id = cfg.image_token_id
    video_token_id = cfg.video_token_id
    vision_start_token_id = cfg.vision_start_token_id

    # Place token ids next to the embeddings when there are any.
    if len(multimodal_embeddings):
        device = multimodal_embeddings[0].device
    else:
        device = mrope_positions.device

    ids = torch.as_tensor(input_ids, device=device, dtype=torch.long)

    stripped_embeddings = []
    trailing_positions = []
    for mm in multimodal_embeddings:
        stripped_embeddings.append(mm[:, :-4])
        # Transpose to (4, num_tokens) and cast to integer positions.
        trailing_positions.append(mm[:, -4:].permute(1, 0).long())

    positions, mrope_positions_delta = recompute_mrope_positions(
        ids,
        trailing_positions,
        mrope_positions,
        num_computed_tokens,
        vision_start_token_id,
        image_token_id,
        video_token_id,
    )

    return tuple(stripped_embeddings), positions, mrope_positions_delta

Qwen2_5_VLImageEmbeddingInputs

Bases: TensorSchema

Dimensions
  • nf: Number of image features
  • hs: Hidden size
  • ni: Number of images
Historical context
  • image_embeds shape: (num_image_features, hidden_size)
  • num_image_features varies based on the number and resolution of the images.
  • hidden_size must match the hidden size of language model backbone.
  • image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) format
Source code in vllm/model_executor/models/qwen2_5_vl.py
class Qwen2_5_VLImageEmbeddingInputs(TensorSchema):
    """
    Image input variant that carries embeddings instead of pixel values.

    Dimensions:
        - nf: Number of image features
        - hs: Hidden size
        - ni: Number of images

    Historical context:
        - image_embeds shape: (num_image_features, hidden_size)
        - num_image_features varies based on the number and resolution of the
          images.
        - hidden_size must match the hidden size of language model backbone.
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format
    """

    # Literal tag identifying this input variant.
    type: Literal["image_embeds"]

    # 2-D (nf, hs); presumably features of all images concatenated along
    # dim 0 — TODO confirm against the input parser.
    image_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    # One (t, h, w) grid row per image.
    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]

Qwen2_5_VLImagePixelInputs

Bases: TensorSchema

Dimensions
  • np: Number of patches
  • ni: Number of images
  • cps: Number of channels * temporal_patch_size * patch_size * patch_size
Historical context
  • pixel_values shape: (num_patches, num_channels * temporal_patch_size * patch_size * patch_size)
  • image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w) format.
Source code in vllm/model_executor/models/qwen2_5_vl.py
class Qwen2_5_VLImagePixelInputs(TensorSchema):
    """
    Image input variant that carries flattened pixel patches.

    Dimensions:
        - np: Number of patches
        - ni: Number of images
        - cps: Number of channels * temporal_patch_size * patch_size *
          patch_size

    Historical context:
        - pixel_values shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - image_grid_thw shape: (num_images, 3) in (grid_t, grid_h, grid_w)
          format.
    """

    # Literal tag identifying this input variant.
    type: Literal["pixel_values"]

    # 2-D (np, cps) flattened patches. The per-patch width includes
    # temporal_patch_size, so image rows have the same width as video rows.
    pixel_values: Annotated[
        torch.Tensor,
        TensorShape("np", "cps"),
    ]

    # One (t, h, w) grid row per image.
    image_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("ni", 3),
    ]

Qwen2_5_VLVideoEmbeddingInputs

Bases: TensorSchema

Dimensions
  • nf: Number of video features
  • hs: Hidden size
  • nv: Number of videos
Historical context
  • video_embeds shape: (num_video_features, hidden_size)
  • num_video_features varies based on the number and resolution of the videos.
  • hidden_size must match the hidden size of language model backbone.
  • video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format
  • second_per_grid_ts: The video time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. Returned when videos is not None.
  • timestamps: List of timestamp values (in seconds) for each frame after merging. Length equals the temporal dimension after merging.
Source code in vllm/model_executor/models/qwen2_5_vl.py
class Qwen2_5_VLVideoEmbeddingInputs(TensorSchema):
    """
    Video input variant that carries embeddings instead of pixel values.

    Dimensions:
        - nf: Number of video features
        - hs: Hidden size
        - nv: Number of videos

    Historical context:
        - video_embeds shape: (num_video_features, hidden_size)
        - num_video_features varies based on the number and resolution of the
          videos.
        - hidden_size must match the hidden size of language model backbone.
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
        - second_per_grid_ts: The video time interval (in seconds) for each
          grid along the temporal dimension in the 3D position IDs. Returned
          when `videos` is not `None`.
        - timestamps: List of timestamp values (in seconds) for each frame
          after merging. Length equals the temporal dimension after merging.
    """

    # Literal tag identifying this input variant.
    type: Literal["video_embeds"]

    # 2-D (nf, hs) video features.
    video_embeds: Annotated[
        torch.Tensor,
        TensorShape("nf", "hs"),
    ]

    # One (t, h, w) grid row per video.
    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]

    # Optional; one value per video. Defaults to None for this variant.
    second_per_grid_ts: Annotated[
        torch.Tensor | None,
        TensorShape("nv"),
    ] = None
    # Optional per-video frame timestamps (seconds).
    timestamps: list[list[float]] | None = None

Qwen2_5_VLVideoPixelInputs

Bases: TensorSchema

Dimensions
  • np: Number of patches
  • nv: Number of videos
  • ctps: Number of channels * temporal_patch_size * patch_size * patch_size
Historical context
  • pixel_values_videos shape: (num_patches, num_channels * temporal_patch_size * patch_size * patch_size)
  • video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w) format
  • second_per_grid_ts: The video time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs. Returned when videos is not None.
  • timestamps: List of timestamp values (in seconds) for each frame after merging. Length equals the temporal dimension after merging.
Source code in vllm/model_executor/models/qwen2_5_vl.py
class Qwen2_5_VLVideoPixelInputs(TensorSchema):
    """
    Video input variant that carries flattened pixel patches.

    Dimensions:
        - np: Number of patches
        - nv: Number of videos
        - ctps: Number of channels * temporal_patch_size * patch_size *
          patch_size

    Historical context:
        - pixel_values_videos shape: (num_patches, num_channels *
          temporal_patch_size * patch_size * patch_size)
        - video_grid_thw shape: (num_videos, 3) in (grid_t, grid_h, grid_w)
          format
        - second_per_grid_ts: The video time interval (in seconds) for each
          grid along the temporal dimension in the 3D position IDs. Returned
          when `videos` is not `None`.
        - timestamps: List of timestamp values (in seconds) for each frame
          after merging. Length equals the temporal dimension after merging.
    """

    # Literal tag identifying this input variant.
    type: Literal["pixel_values_videos"]

    # 2-D (np, ctps) flattened spatio-temporal patches.
    pixel_values_videos: Annotated[
        torch.Tensor,
        TensorShape("np", "ctps"),
    ]

    # One (t, h, w) grid row per video.
    video_grid_thw: Annotated[
        torch.Tensor,
        TensorShape("nv", 3),
    ]

    # One value per video; typed Optional, but unlike the embedding variant
    # it has no class-level default.
    # NOTE(review): confirm whether omitting it is meant to be an error.
    second_per_grid_ts: Annotated[
        torch.Tensor | None,
        TensorShape("nv"),
    ]

    # Optional per-video frame timestamps (seconds).
    timestamps: list[list[float]] | None = None

Qwen2_5_VisionTransformer

Bases: Module

Source code in vllm/model_executor/models/qwen2_5_vl.py
 576
 577
 578
 579
 580
 581
 582
 583
 584
 585
 586
 587
 588
 589
 590
 591
 592
 593
 594
 595
 596
 597
 598
 599
 600
 601
 602
 603
 604
 605
 606
 607
 608
 609
 610
 611
 612
 613
 614
 615
 616
 617
 618
 619
 620
 621
 622
 623
 624
 625
 626
 627
 628
 629
 630
 631
 632
 633
 634
 635
 636
 637
 638
 639
 640
 641
 642
 643
 644
 645
 646
 647
 648
 649
 650
 651
 652
 653
 654
 655
 656
 657
 658
 659
 660
 661
 662
 663
 664
 665
 666
 667
 668
 669
 670
 671
 672
 673
 674
 675
 676
 677
 678
 679
 680
 681
 682
 683
 684
 685
 686
 687
 688
 689
 690
 691
 692
 693
 694
 695
 696
 697
 698
 699
 700
 701
 702
 703
 704
 705
 706
 707
 708
 709
 710
 711
 712
 713
 714
 715
 716
 717
 718
 719
 720
 721
 722
 723
 724
 725
 726
 727
 728
 729
 730
 731
 732
 733
 734
 735
 736
 737
 738
 739
 740
 741
 742
 743
 744
 745
 746
 747
 748
 749
 750
 751
 752
 753
 754
 755
 756
 757
 758
 759
 760
 761
 762
 763
 764
 765
 766
 767
 768
 769
 770
 771
 772
 773
 774
 775
 776
 777
 778
 779
 780
 781
 782
 783
 784
 785
 786
 787
 788
 789
 790
 791
 792
 793
 794
 795
 796
 797
 798
 799
 800
 801
 802
 803
 804
 805
 806
 807
 808
 809
 810
 811
 812
 813
 814
 815
 816
 817
 818
 819
 820
 821
 822
 823
 824
 825
 826
 827
 828
 829
 830
 831
 832
 833
 834
 835
 836
 837
 838
 839
 840
 841
 842
 843
 844
 845
 846
 847
 848
 849
 850
 851
 852
 853
 854
 855
 856
 857
 858
 859
 860
 861
 862
 863
 864
 865
 866
 867
 868
 869
 870
 871
 872
 873
 874
 875
 876
 877
 878
 879
 880
 881
 882
 883
 884
 885
 886
 887
 888
 889
 890
 891
 892
 893
 894
 895
 896
 897
 898
 899
 900
 901
 902
 903
 904
 905
 906
 907
 908
 909
 910
 911
 912
 913
 914
 915
 916
 917
 918
 919
 920
 921
 922
 923
 924
 925
 926
 927
 928
 929
 930
 931
 932
 933
 934
 935
 936
 937
 938
 939
 940
 941
 942
 943
 944
 945
 946
 947
 948
 949
 950
 951
 952
 953
 954
 955
 956
 957
 958
 959
 960
 961
 962
 963
 964
 965
 966
 967
 968
 969
 970
 971
 972
 973
 974
 975
 976
 977
 978
 979
 980
 981
 982
 983
 984
 985
 986
 987
 988
 989
 990
 991
 992
 993
 994
 995
 996
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
class Qwen2_5_VisionTransformer(nn.Module):
    """Qwen2.5-VL vision tower.

    ``forward`` runs: patch embedding -> reorder merge groups into
    attention-window order -> a stack of ``Qwen2_5_VisionBlock``s (blocks
    listed in ``vision_config.fullatt_block_indexes`` attend over each whole
    frame, all other blocks attend only within windows) ->
    ``Qwen2_5_VisionPatchMerger`` -> undo the window reordering.
    """

    def __init__(
        self,
        vision_config: Qwen2_5_VLVisionConfig,
        norm_eps: float = 1e-6,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build the patch embedding, rotary embedding, blocks and merger.

        Args:
            vision_config: Vision tower configuration (patch sizes, depth,
                widths, window size, full-attention block indexes, ...).
            norm_eps: Epsilon for the RMSNorm layers used by blocks/merger.
            quant_config: Optional quantization config forwarded to the
                blocks and the merger.
            prefix: Parameter-name prefix used to build per-layer prefixes.
        """
        super().__init__()

        patch_size = vision_config.patch_size
        temporal_patch_size = vision_config.temporal_patch_size
        in_channels = vision_config.in_channels
        depth = vision_config.depth
        self.hidden_size = vision_config.hidden_size
        self.num_heads = vision_config.num_heads
        self.out_hidden_size = vision_config.out_hidden_size

        # args for get_window_index_thw
        self.window_size = vision_config.window_size
        self.patch_size = vision_config.patch_size
        self.spatial_merge_size = vision_config.spatial_merge_size
        self.fullatt_block_indexes = vision_config.fullatt_block_indexes
        # Number of patches that form one merge group (merge_size ** 2).
        self.spatial_merge_unit = self.spatial_merge_size**2
        self.patch_embed = Qwen2_5_VisionPatchEmbed(
            patch_size=patch_size,
            temporal_patch_size=temporal_patch_size,
            in_channels=in_channels,
            hidden_size=self.hidden_size,
        )

        norm_layer = partial(RMSNorm, eps=norm_eps)
        head_dim = self.hidden_size // self.num_heads
        # Rotary tables for 2-D (row, column) patch positions; only a 0.5
        # fraction of the head dimension is configured as rotary.
        self.rotary_pos_emb = get_rope(
            head_size=head_dim,
            max_position=8192,
            is_neox_style=True,
            rope_parameters={"partial_rotary_factor": 0.5},
        )

        # Backend selection decides whether max_seqlen is actually computed
        # (see compute_attn_mask_seqlen).
        self.attn_backend = get_vit_attn_backend(
            head_size=head_dim,
            dtype=torch.get_default_dtype(),
        )

        self.blocks = nn.ModuleList(
            [
                Qwen2_5_VisionBlock(
                    dim=self.hidden_size,
                    num_heads=self.num_heads,
                    mlp_hidden_dim=vision_config.intermediate_size,
                    act_fn=get_act_and_mul_fn(vision_config.hidden_act),
                    norm_layer=norm_layer,
                    quant_config=quant_config,
                    prefix=f"{prefix}.blocks.{layer_idx}",
                )
                for layer_idx in range(depth)
            ]
        )

        self.merger = Qwen2_5_VisionPatchMerger(
            d_model=vision_config.out_hidden_size,
            context_dim=self.hidden_size,
            norm_layer=norm_layer,
            spatial_merge_size=self.spatial_merge_size,
            quant_config=quant_config,
            prefix=f"{prefix}.merger",
        )

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the patch-embedding projection weight."""
        return self.patch_embed.proj.weight.dtype

    @property
    def device(self) -> torch.device:
        """Device of the patch-embedding projection weight."""
        return self.patch_embed.proj.weight.device

    def rotary_pos_emb_thw(self, t, h, w):
        """Return rotary cos/sin tables for a ``t x h x w`` patch grid.

        Each patch gets a (row, column) position. Positions are listed so
        that every ``spatial_merge_size x spatial_merge_size`` group of
        patches is contiguous. The same position list is repeated for each
        of the ``t`` temporal steps.

        Returns:
            ``(cos, sin)``, each shaped
            ``(t * h * w // spatial_merge_unit, spatial_merge_unit, -1)``.
            The last dim concatenates the row and column embeddings, and
            whole merge groups can be reordered along dim 0.
        """
        hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
        wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
        # Regroup (h, w) into (h/m, m, w/m, m) and move the two intra-group
        # axes last so each m x m merge group is contiguous after flatten.
        hpos_ids = (
            hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            .permute(0, 2, 1, 3)
            .flatten()
        )
        wpos_ids = (
            wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            .permute(0, 2, 1, 3)
            .flatten()
        )
        pos_ids = torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)
        max_size = max(h, w)

        # Use pre-computed cos_sin_cache from RotaryEmbedding
        cos, sin = self.rotary_pos_emb.get_cos_sin(max_size)

        # (N, 2) positions -> (N, 2, D) -> (N, 2D): row and column tables
        # side by side.
        cos_combined = cos[pos_ids].flatten(1)
        sin_combined = sin[pos_ids].flatten(1)

        cos_combined = cos_combined.reshape(
            cos_combined.shape[0] // self.spatial_merge_unit,
            self.spatial_merge_unit,
            -1,
        )
        sin_combined = sin_combined.reshape(
            sin_combined.shape[0] // self.spatial_merge_unit,
            self.spatial_merge_unit,
            -1,
        )

        return cos_combined, sin_combined

    def get_window_index_thw(self, grid_t, grid_h, grid_w):
        """Compute the window ordering of merge groups for one grid.

        The merged grid of each temporal step has
        ``grid_h // spatial_merge_size`` rows and
        ``grid_w // spatial_merge_size`` columns. It is tiled into square
        windows with ``window_size // spatial_merge_size // patch_size``
        merged positions per side; windows on the bottom/right edge may be
        partial.

        Returns:
            index_new: 1-D permutation of merge-group indices (over all
                ``grid_t`` steps) that lists the groups window by window.
            cu_seqlens_tmp: int32 cumulative patch counts (merge groups
                times ``spatial_merge_unit``) at the end of each non-empty
                window. It does not start with 0.
        """
        vit_merger_window_size = (
            self.window_size // self.spatial_merge_size // self.patch_size
        )

        llm_grid_h = grid_h // self.spatial_merge_size
        llm_grid_w = grid_w // self.spatial_merge_size
        index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
            grid_t, llm_grid_h, llm_grid_w
        )
        # NOTE: when a side is already a multiple of the window size this
        # pads one full extra window; such windows contain only the -100
        # sentinel and are dropped below.
        pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
        pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
        num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
        num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
        # -100 marks padding positions that are not real merge groups.
        index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
        index_padded = index_padded.reshape(
            grid_t,
            num_windows_h,
            vit_merger_window_size,
            num_windows_w,
            vit_merger_window_size,
        )
        index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
            grid_t,
            num_windows_h * num_windows_w,
            vit_merger_window_size,
            vit_merger_window_size,
        )
        # Real (non-sentinel) merge groups per window.
        seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
        index_padded = index_padded.reshape(-1)
        index_new = index_padded[index_padded != -100]
        cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit
        cu_seqlens_tmp = cu_seqlens_tmp.to(dtype=torch.int32)
        # Empty windows repeat the previous offset; collapse them.
        cu_seqlens_tmp = torch.unique_consecutive(cu_seqlens_tmp)

        return index_new, cu_seqlens_tmp

    @lru_cache(maxsize=1024)  # noqa: B019
    def get_rope_by_thw(self, t, h, w):
        """Return per-grid rotary tables and attention layout (memoized).

        Combines ``get_window_index_thw`` and ``rotary_pos_emb_thw``. The
        cos/sin tables are reordered into window order and flattened back
        to one row per patch.

        NOTE: results are cached by ``lru_cache`` (keyed on ``self`` and
        ``(t, h, w)``). The returned tensors are shared between calls and
        must not be modified in place.

        Returns:
            ``(cos, sin, window_index, cu_seqlens_window, cu_seqlens)``.
            ``cu_seqlens`` holds ``h * w`` repeated ``t`` times (one entry
            per frame, not yet cumulative).
        """
        window_index_thw, cu_seqlens_window_thw = self.get_window_index_thw(t, h, w)
        cos_thw, sin_thw = self.rotary_pos_emb_thw(t, h, w)

        cos_thw = cos_thw[window_index_thw, :, :]
        cos_thw = cos_thw.flatten(start_dim=0, end_dim=1)
        sin_thw = sin_thw[window_index_thw, :, :]
        sin_thw = sin_thw.flatten(start_dim=0, end_dim=1)

        cu_seqlens_thw = torch.repeat_interleave(
            torch.tensor([h * w], dtype=torch.int32), t
        )
        return (
            cos_thw,
            sin_thw,
            window_index_thw,
            cu_seqlens_window_thw,
            cu_seqlens_thw,
        )

    def compute_attn_mask_seqlen(
        self,
        cu_seqlens: torch.Tensor,
    ) -> torch.Tensor:
        """Return the longest sequence length described by ``cu_seqlens``.

        Only the FLASH_ATTN, ROCM_AITER_FA and TRITON_ATTN backends get the
        real maximum. Any other backend gets a 0-d zero tensor of the
        default dtype as a placeholder.
        """
        if self.attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
            AttentionBackendEnum.TRITON_ATTN,
        }:
            return (cu_seqlens[1:] - cu_seqlens[:-1]).max()
        # Placeholder, allocated only when it is actually returned.
        return torch.zeros([], device=cu_seqlens.device)

    @staticmethod
    def invert_permutation(perm: torch.Tensor) -> torch.Tensor:
        """Return ``inv`` such that ``inv[perm[i]] == i`` for all ``i``."""
        # building the inverse permutation in O(n) time
        inv = torch.empty_like(perm, pin_memory=is_pin_memory_available())
        inv[perm] = torch.arange(perm.numel(), device=perm.device, dtype=perm.dtype)
        return inv

    @staticmethod
    def _pad_cu_seqlens(
        cu_seqlens: torch.Tensor, num_seqs: int | None
    ) -> torch.Tensor:
        """Pad cumulative sequence lengths so they describe ``num_seqs`` seqs.

        Appends copies of the last cumulative offset, so every padded entry
        is an empty (zero-length) sequence. After padding the tensor has
        ``num_seqs + 1`` entries. When ``num_seqs`` is ``None``, or the
        tensor already describes at least ``num_seqs`` sequences, it is
        returned unchanged (it is never truncated).
        """
        if num_seqs is None:
            return cu_seqlens
        missing = num_seqs - (len(cu_seqlens) - 1)
        if missing <= 0:
            return cu_seqlens
        padding = torch.full(
            (missing,),
            cu_seqlens[-1],
            dtype=cu_seqlens.dtype,
            device=cu_seqlens.device,
        )
        return torch.cat((cu_seqlens, padding))

    def prepare_encoder_metadata(
        self,
        grid_thw: list[list[int]],
        *,
        max_batch_size: int | None = None,
        max_frames_per_batch: int | None = None,
        max_window_seqs_per_batch: int | None = None,
        max_seqlen_override: int | None = None,
        max_seqlen_window_override: int | None = None,
        device: torch.device | None = None,
    ) -> dict[str, torch.Tensor]:
        """Compute encoder metadata from grid_thw.

        Shared by the eager forward path, CUDA graph capture, and
        CUDA graph replay to avoid duplicated implementation.

        Args:
            grid_thw: Grid configurations as list of [t, h, w].
            max_batch_size: If set, pad cu_seqlens so it describes this
                many sequences (``max_batch_size + 1`` entries); padded
                entries are empty sequences. Never truncates. Needed for
                CUDA graph capture/replay.
            max_frames_per_batch: If set, overrides max_batch_size for
                cu_seqlens padding. For video inputs each item contributes
                T attention sequences (frames); this sizes the buffer to
                the total frame budget so video replays never overflow.
            max_window_seqs_per_batch: If set, pad cu_window_seqlens to this
                number of window sequences. This keeps cu_window_seqlens shape
                stable across capture/replay for CUDA graph safety.
            max_seqlen_override: If set, use this value for max_seqlen
                instead of computing from cu_seqlens (needed for CUDA
                graph capture to cover worst-case replay scenarios).
            max_seqlen_window_override: If set, use this value for
                window-attention max_seqlen instead of computing from
                cu_window_seqlens (needed for CUDA graph capture to
                cover worst-case replay scenarios).
            device: Device to place tensors on. Defaults to self.device.

        Returns:
            Dict with keys ``rotary_pos_emb_cos``, ``rotary_pos_emb_sin``,
            ``window_index``, ``reverse_indices``, ``cu_seqlens``,
            ``cu_window_seqlens`` (all moved to ``device``) and
            ``max_seqlen_full``, ``max_seqlen_window`` (0-d tensors, not
            moved to ``device``).
        """

        if device is None:
            device = self.device
        metadata: dict[str, torch.Tensor] = {}

        # patchify
        rotary_pos_emb_cos = []
        rotary_pos_emb_sin = []
        window_index: list = []
        cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)]
        cu_seqlens: list = []

        # Running offsets that turn per-item (local) indices into batch-wide
        # ones.
        window_index_id = 0
        cu_window_seqlens_last = 0
        for t, h, w in grid_thw:
            t, h, w = int(t), int(h), int(w)
            llm_h = h // self.spatial_merge_size
            llm_w = w // self.spatial_merge_size

            (
                cos_thw,
                sin_thw,
                window_index_thw,
                cu_seqlens_window_thw,
                cu_seqlens_thw,
            ) = self.get_rope_by_thw(t, h, w)

            # Out-of-place adds: the cached tensors must stay untouched.
            window_index.append(window_index_thw + window_index_id)
            window_index_id += t * llm_h * llm_w

            cu_seqlens_window_thw = cu_seqlens_window_thw + cu_window_seqlens_last
            cu_window_seqlens_last = cu_seqlens_window_thw[-1]
            cu_window_seqlens.append(cu_seqlens_window_thw)

            rotary_pos_emb_cos.append(cos_thw)
            rotary_pos_emb_sin.append(sin_thw)

            cu_seqlens.append(cu_seqlens_thw)

        rotary_pos_emb_cos = torch.cat(rotary_pos_emb_cos)
        rotary_pos_emb_sin = torch.cat(rotary_pos_emb_sin)
        window_index = torch.cat(window_index)
        # compute reverse indices
        reverse_indices = self.invert_permutation(window_index)
        cu_window_seqlens = torch.cat(cu_window_seqlens)
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
        cu_seqlens = torch.cat(cu_seqlens)
        cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
        cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

        # Pad cu_seqlens to the required number of sequences.
        # For videos each item contributes T frames = T attention sequences,
        # so the total can exceed max_batch_size. max_frames_per_batch
        # overrides the pad target when set.
        pad_to = (
            max_frames_per_batch if max_frames_per_batch is not None else max_batch_size
        )
        cu_seqlens = self._pad_cu_seqlens(cu_seqlens, pad_to)

        # Pad cu_window_seqlens to a stable number of window sequences, in
        # the same way (padded entries are empty sequences).
        cu_window_seqlens = self._pad_cu_seqlens(
            cu_window_seqlens, max_window_seqs_per_batch
        )

        # transformers
        # pre-compute seqlens for window/full attn to reduce cuMemcpy operations
        if max_seqlen_override is None:
            max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
        else:
            max_seqlen_full = torch.tensor(max_seqlen_override, dtype=torch.int32)
        if max_seqlen_window_override is None:
            max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
        else:
            max_seqlen_window = torch.tensor(
                max_seqlen_window_override, dtype=torch.int32
            )

        cu_seqlens = cu_seqlens.to(device=device, non_blocking=True)
        cu_window_seqlens = cu_window_seqlens.to(device=device, non_blocking=True)
        rotary_pos_emb_cos = rotary_pos_emb_cos.to(device=device, non_blocking=True)
        rotary_pos_emb_sin = rotary_pos_emb_sin.to(device=device, non_blocking=True)
        window_index = window_index.to(device=device, non_blocking=True)
        reverse_indices = reverse_indices.to(device=device, non_blocking=True)

        metadata["rotary_pos_emb_cos"] = rotary_pos_emb_cos
        metadata["rotary_pos_emb_sin"] = rotary_pos_emb_sin
        metadata["window_index"] = window_index
        metadata["reverse_indices"] = reverse_indices
        metadata["cu_seqlens"] = cu_seqlens
        metadata["cu_window_seqlens"] = cu_window_seqlens
        metadata["max_seqlen_full"] = max_seqlen_full
        metadata["max_seqlen_window"] = max_seqlen_window

        return metadata

    def forward(
        self,
        x: torch.Tensor,
        grid_thw: list[list[int]],
        *,
        encoder_metadata: dict[str, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        """Encode flattened patches into merged vision embeddings.

        Args:
            x: Patch pixel values consumed by ``patch_embed``; first cast to
                this module's device and dtype.
            grid_thw: Per-item ``[t, h, w]`` patch grids. Only used when
                ``encoder_metadata`` is not supplied.
            encoder_metadata: Optional output of
                ``prepare_encoder_metadata`` (e.g. precomputed for CUDA
                graphs). Computed from ``grid_thw`` when ``None``.

        Returns:
            Merger output rows, restored to the original (pre-window) order.
        """
        hidden_states = x.to(device=self.device, dtype=self.dtype)
        hidden_states = self.patch_embed(hidden_states)

        seq_len = hidden_states.shape[0]
        if encoder_metadata is None:
            encoder_metadata = self.prepare_encoder_metadata(grid_thw)

        rotary_pos_emb_cos = encoder_metadata["rotary_pos_emb_cos"]
        rotary_pos_emb_sin = encoder_metadata["rotary_pos_emb_sin"]
        window_index = encoder_metadata["window_index"]
        reverse_indices = encoder_metadata["reverse_indices"]
        cu_seqlens = encoder_metadata["cu_seqlens"]
        cu_window_seqlens = encoder_metadata["cu_window_seqlens"]
        max_seqlen_full = encoder_metadata["max_seqlen_full"]
        max_seqlen_window = encoder_metadata["max_seqlen_window"]

        # Reorder whole merge groups into window order.
        hidden_states = hidden_states.reshape(
            seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1
        )
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)

        hidden_states = hidden_states.unsqueeze(1)

        for layer_num, blk in enumerate(self.blocks):
            # Full-attention blocks use per-frame sequences; the rest use
            # per-window sequences.
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
                max_seqlen_now = max_seqlen_full
            else:
                cu_seqlens_now = cu_window_seqlens
                max_seqlen_now = max_seqlen_window

            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                rotary_pos_emb_cos=rotary_pos_emb_cos,
                rotary_pos_emb_sin=rotary_pos_emb_sin,
                max_seqlen=max_seqlen_now,
            )

        # For Qwen2.5-VL-3B, float16 will overflow at last block
        # for long visual tokens sequences.
        if hidden_states.dtype == torch.float16:
            hidden_states = cast_overflow_tensors(hidden_states)

        # adapter
        hidden_states = self.merger(hidden_states)
        # Undo the window reordering.
        hidden_states = hidden_states[reverse_indices, :]
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing separate q/k/v and gate/up shards.

        - Names containing ``attn.q.``/``attn.k.``/``attn.v.`` are loaded
          into the fused ``attn.qkv.`` parameter.
        - ``mlp.gate_proj.``/``mlp.up_proj.`` are loaded into
          ``mlp.gate_up_proj.``.
        - Both fused cases go through the parameter's ``weight_loader``
          with the shard id.
        - All other weights use the parameter's ``weight_loader`` if it has
          one, else ``default_weight_loader``.

        Returns:
            The set of (possibly renamed) parameter names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("attn.qkv.", "attn.q.", "q"),
            ("attn.qkv.", "attn.k.", "k"),
            ("attn.qkv.", "attn.v.", "v"),
            ("mlp.gate_up_proj.", "mlp.gate_proj.", 0),
            ("mlp.gate_up_proj.", "mlp.up_proj.", 1),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        loaded_params: set[str] = set()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

prepare_encoder_metadata

prepare_encoder_metadata(
    grid_thw: list[list[int]],
    *,
    max_batch_size: int | None = None,
    max_frames_per_batch: int | None = None,
    max_window_seqs_per_batch: int | None = None,
    max_seqlen_override: int | None = None,
    max_seqlen_window_override: int | None = None,
    device: device | None = None,
) -> dict[str, Tensor]

Compute encoder metadata from grid_thw.

Shared by the eager forward path, CUDA graph capture, and CUDA graph replay to avoid duplicated implementation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `grid_thw` | `list[list[int]]` | Grid configurations as a list of `[t, h, w]`. | *required* |
| `max_batch_size` | `int \| None` | If set, pad `cu_seqlens` so that it describes this many sequences (it then has `max_batch_size + 1` entries); needed for CUDA graph capture/replay. | `None` |
| `max_frames_per_batch` | `int \| None` | If set, overrides `max_batch_size` as the `cu_seqlens` padding target. For video inputs each item contributes T attention sequences (frames); this sizes the buffer to the total frame budget so video replays never overflow. | `None` |
| `max_window_seqs_per_batch` | `int \| None` | If set, pad `cu_window_seqlens` to this number of window sequences. This keeps the shape of `cu_window_seqlens` stable across capture/replay for CUDA graph safety. | `None` |
| `max_seqlen_override` | `int \| None` | If set, use this value for `max_seqlen` instead of computing it from `cu_seqlens` (needed for CUDA graph capture to cover worst-case replay scenarios). | `None` |
| `max_seqlen_window_override` | `int \| None` | If set, use this value for the window-attention `max_seqlen` instead of computing it from `cu_window_seqlens` (needed for CUDA graph capture to cover worst-case replay scenarios). | `None` |
| `device` | `device \| None` | Device to place tensors on. Defaults to `self.device`. | `None` |
Source code in vllm/model_executor/models/qwen2_5_vl.py
def prepare_encoder_metadata(
    self,
    grid_thw: list[list[int]],
    *,
    max_batch_size: int | None = None,
    max_frames_per_batch: int | None = None,
    max_window_seqs_per_batch: int | None = None,
    max_seqlen_override: int | None = None,
    max_seqlen_window_override: int | None = None,
    device: torch.device | None = None,
) -> dict[str, torch.Tensor]:
    """Compute encoder metadata from grid_thw.

    Shared by the eager forward path, CUDA graph capture, and
    CUDA graph replay to avoid duplicated implementation.

    Args:
        grid_thw: Grid configurations as list of [t, h, w].
        max_batch_size: If set, pad cu_seqlens to this size
            (needed for CUDA graph capture/replay).
        max_frames_per_batch: If set, overrides max_batch_size for
            cu_seqlens padding. For video inputs each item contributes
            T attention sequences (frames); this sizes the buffer to
            the total frame budget so video replays never overflow.
        max_window_seqs_per_batch: If set, pad cu_window_seqlens to this
            number of window sequences. This keeps cu_window_seqlens shape
            stable across capture/replay for CUDA graph safety.
        max_seqlen_override: If set, use this value for max_seqlen
            instead of computing from cu_seqlens (needed for CUDA
            graph capture to cover worst-case replay scenarios).
        max_seqlen_window_override: If set, use this value for
            window-attention max_seqlen instead of computing from
            cu_window_seqlens (needed for CUDA graph capture to
            cover worst-case replay scenarios).
        device: Device to place tensors on. Defaults to self.device.
    """

    if device is None:
        device = self.device
    metadata: dict[str, torch.Tensor] = {}

    # patchify
    rotary_pos_emb_cos = []
    rotary_pos_emb_sin = []
    window_index: list = []
    cu_window_seqlens: list = [torch.tensor([0], dtype=torch.int32)]
    cu_seqlens: list = []

    window_index_id = 0
    cu_window_seqlens_last = 0
    for t, h, w in grid_thw:
        t, h, w = int(t), int(h), int(w)
        llm_h = h // self.spatial_merge_size
        llm_w = w // self.spatial_merge_size

        (
            cos_thw,
            sin_thw,
            window_index_thw,
            cu_seqlens_window_thw,
            cu_seqlens_thw,
        ) = self.get_rope_by_thw(t, h, w)

        window_index.append(window_index_thw + window_index_id)
        window_index_id += t * llm_h * llm_w

        cu_seqlens_window_thw = cu_seqlens_window_thw + cu_window_seqlens_last
        cu_window_seqlens_last = cu_seqlens_window_thw[-1]
        cu_window_seqlens.append(cu_seqlens_window_thw)

        rotary_pos_emb_cos.append(cos_thw)
        rotary_pos_emb_sin.append(sin_thw)

        cu_seqlens.append(cu_seqlens_thw)

    rotary_pos_emb_cos = torch.cat(rotary_pos_emb_cos)
    rotary_pos_emb_sin = torch.cat(rotary_pos_emb_sin)
    window_index = torch.cat(window_index)
    # compute reverse indices
    reverse_indices = self.invert_permutation(window_index)
    cu_window_seqlens = torch.cat(cu_window_seqlens)
    cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
    cu_seqlens = torch.cat(cu_seqlens)
    cu_seqlens = torch.cumsum(cu_seqlens, dim=0, dtype=torch.int32)
    cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)

    # Pad cu_seqlens to the required number of sequences.
    # For videos each item contributes T frames = T attention sequences,
    # so the total can exceed max_batch_size. max_frames_per_batch
    # overrides the pad target when set.
    pad_to = (
        max_frames_per_batch if max_frames_per_batch is not None else max_batch_size
    )
    if pad_to is not None:
        num_seqs = len(cu_seqlens) - 1
        if num_seqs < pad_to:
            cu_seqlens = torch.cat(
                (
                    cu_seqlens,
                    torch.full(
                        (pad_to - num_seqs,),
                        cu_seqlens[-1],
                        dtype=cu_seqlens.dtype,
                        device=cu_seqlens.device,
                    ),
                )
            )

    # Pad cu_window_seqlens to a stable number of window sequences.
    # Like cu_seqlens, we repeat the last cumulative offset so padded
    # entries represent empty sequences.
    if max_window_seqs_per_batch is not None:
        num_window_seqs = len(cu_window_seqlens) - 1
        if num_window_seqs < max_window_seqs_per_batch:
            cu_window_seqlens = torch.cat(
                (
                    cu_window_seqlens,
                    torch.full(
                        (max_window_seqs_per_batch - num_window_seqs,),
                        cu_window_seqlens[-1],
                        dtype=cu_window_seqlens.dtype,
                        device=cu_window_seqlens.device,
                    ),
                )
            )

    # transformers
    # pre-compute seqlens for window/full attn to reduce cuMemcpy operations
    if max_seqlen_override is None:
        max_seqlen_full = self.compute_attn_mask_seqlen(cu_seqlens)
    else:
        max_seqlen_full = torch.tensor(max_seqlen_override, dtype=torch.int32)
    if max_seqlen_window_override is None:
        max_seqlen_window = self.compute_attn_mask_seqlen(cu_window_seqlens)
    else:
        max_seqlen_window = torch.tensor(
            max_seqlen_window_override, dtype=torch.int32
        )

    cu_seqlens = cu_seqlens.to(device=device, non_blocking=True)
    cu_window_seqlens = cu_window_seqlens.to(device=device, non_blocking=True)
    rotary_pos_emb_cos = rotary_pos_emb_cos.to(device=device, non_blocking=True)
    rotary_pos_emb_sin = rotary_pos_emb_sin.to(device=device, non_blocking=True)
    window_index = window_index.to(device=device, non_blocking=True)
    reverse_indices = reverse_indices.to(device=device, non_blocking=True)

    metadata["rotary_pos_emb_cos"] = rotary_pos_emb_cos
    metadata["rotary_pos_emb_sin"] = rotary_pos_emb_sin
    metadata["window_index"] = window_index
    metadata["reverse_indices"] = reverse_indices
    metadata["cu_seqlens"] = cu_seqlens
    metadata["cu_window_seqlens"] = cu_window_seqlens
    metadata["max_seqlen_full"] = max_seqlen_full
    metadata["max_seqlen_window"] = max_seqlen_window

    return metadata