"""Tests for ``pdf2md.gpu``: parsing ``nvidia-smi`` CSV rows and picking a GPU."""

from __future__ import annotations

import pytest

from pdf2md.gpu import GpuInfo, parse_nvidia_smi_gpus, select_gpu


def _two_rtx_gpus() -> tuple[GpuInfo, GpuInfo]:
    """Return an RTX 4060 (8192 MiB, index 0) and an RTX 4090 (24564 MiB, index 1)."""
    return (
        GpuInfo(index=0, name="NVIDIA RTX 4060", memory_total_mib=8192, driver_version="577.00"),
        GpuInfo(index=1, name="NVIDIA RTX 4090", memory_total_mib=24564, driver_version="577.00"),
    )


def test_parse_nvidia_smi_output_with_one_rtx_gpu() -> None:
    """A single row with a bare memory number parses into a one-element tuple."""
    gpus = parse_nvidia_smi_gpus("0, NVIDIA GeForce RTX 4090, 24564, 577.00\n")

    assert gpus == (
        GpuInfo(
            index=0,
            name="NVIDIA GeForce RTX 4090",
            memory_total_mib=24564,
            driver_version="577.00",
        ),
    )
    assert gpus[0].pre_turing_risk is False


def test_parse_nvidia_smi_output_with_multiple_gpus_and_mib_suffix() -> None:
    """Several rows parse in order, and a ``MiB`` suffix on the memory field is accepted."""
    gpus = parse_nvidia_smi_gpus(
        "0, NVIDIA GeForce GTX 1070 Ti, 8192 MiB, 577.00\n"
        "1, NVIDIA RTX A5000, 24564 MiB, 577.00\n"
    )

    assert [gpu.index for gpu in gpus] == [0, 1]
    assert [gpu.memory_total_mib for gpu in gpus] == [8192, 24564]
    # The GTX card is expected to be flagged, the RTX card is not.
    assert gpus[0].pre_turing_risk is True
    assert gpus[1].pre_turing_risk is False


def test_parse_nvidia_smi_output_ignores_blank_lines_and_rejects_malformed_memory() -> None:
    """A non-numeric memory field raises ``ValueError`` mentioning "memory"."""
    # The leading newline puts a blank line ahead of the malformed row; the
    # error is expected to come from the memory field, not from the blank line.
    with pytest.raises(ValueError, match="memory"):
        parse_nvidia_smi_gpus("\n0, NVIDIA RTX 4090, not-memory, 577.00\n")


def test_select_gpu_auto_chooses_largest_vram_gpu() -> None:
    """``"auto"`` selects the GPU with the most memory and reports its CUDA device."""
    gpus = _two_rtx_gpus()

    selected = select_gpu(gpus, "auto")

    assert selected.gpu == gpus[1]
    assert selected.cuda_device == "cuda:1"


def test_select_gpu_accepts_cuda_and_numeric_requests() -> None:
    """Both ``"cuda:N"`` and a bare ``"N"`` request resolve to the GPU with index N."""
    gpus = _two_rtx_gpus()

    assert select_gpu(gpus, "cuda:1").gpu == gpus[1]
    assert select_gpu(gpus, "1").cuda_device == "cuda:1"


def test_select_gpu_errors_when_requested_gpu_is_absent() -> None:
    """Requesting an index that is not in the visible set raises ``ValueError``."""
    gpus = (
        GpuInfo(index=0, name="NVIDIA RTX 4060", memory_total_mib=8192, driver_version="577.00"),
    )

    with pytest.raises(ValueError, match="not visible"):
        select_gpu(gpus, "cuda:1")


def test_select_gpu_auto_errors_without_visible_gpus() -> None:
    """``"auto"`` with an empty GPU tuple raises ``ValueError``."""
    with pytest.raises(ValueError, match="No visible NVIDIA GPU"):
        select_gpu((), "auto")