from collections.abc import Iterable
from functools import lru_cache, reduce
from typing import Any, TypeVar, cast
import pytest
from grunnur import API, Context, Device, DeviceFilter, Platform, PlatformFilter, all_api_ids
[docs]
def pytest_addoption(parser: pytest.Parser) -> None:
"""
Adds the following command-line options:
* ``--api``: select a specific API to test (out of returned by :py:func:`grunnur.all_api_ids`).
* ``--platform-include-mask``: run tests only on platforms whose names matches the mask.
* ``--platform-exclude-mask``: exclude platforms whose names matches the mask from the tests.
* ``--device-include-mask``: run tests only on devices whose names matches the mask.
* ``--device-exclude-mask``: exclude devices whose names matches the mask from the tests.
* ``--include-duplicate-devices``: if there are devices with the same name in the platform,
run tests on all of them.
* ``--include-pure-parallel-devices``: include pure parallel devices in the tests
(that is, those not supporting synchronization within a block/work group).
"""
api_shortcuts = [api_id.shortcut for api_id in all_api_ids()]
parser.addoption(
"--api",
action="store",
help="GPGPU API: " + "/".join(api_shortcuts) + " (or all available if not given)",
default=None,
choices=api_shortcuts,
)
parser.addoption(
"--platform-include-mask",
action="append",
help="Run tests on matching platforms only",
default=[],
)
parser.addoption(
"--platform-exclude-mask",
action="append",
help="Run tests on matching platforms only",
default=[],
)
parser.addoption(
"--device-include-mask",
action="append",
help="Run tests on matching devices only",
default=[],
)
parser.addoption(
"--device-exclude-mask",
action="append",
help="Run tests on matching devices only",
default=[],
)
parser.addoption(
"--include-duplicate-devices",
action="store_true",
help="Run tests on all available devices and not only on uniquely named ones",
default=False,
)
parser.addoption(
"--include-pure-parallel-devices",
action="store_true",
help="Include pure parallel devices (not supporting synchronization within a work group)",
default=False,
)
[docs]
@lru_cache
def get_apis(config: pytest.Config) -> list[API]:
"""Returns the list of APIs filtered by the test configuration."""
return API.all_by_shortcut(config.option.api)
_T = TypeVar("_T")
def _concatenate(lists: Iterable[list[_T]]) -> list[_T]:
return reduce(lambda x, y: x + y, lists, [])
@lru_cache
def _get_device_sets(
config: pytest.Config, unique_devices_only_override: bool | None = None
) -> list[list[Device]]:
if unique_devices_only_override is not None:
unique_devices_only = unique_devices_only_override
else:
unique_devices_only = not config.option.include_duplicate_devices
platforms = get_platforms(config)
return [
Device.all_filtered(
platform,
DeviceFilter(
include_masks=config.option.device_include_mask,
exclude_masks=config.option.device_exclude_mask,
unique_only=unique_devices_only,
exclude_pure_parallel=not config.option.include_pure_parallel_devices,
),
)
for platform in platforms
]
[docs]
@lru_cache
def get_devices(config: pytest.Config) -> list[Device]:
"""
Returns the list of devices filtered by the test configuration
(concatenated for all filtered platforms and APIs).
"""
return _concatenate(_get_device_sets(config))
[docs]
@lru_cache
def get_multi_device_sets(config: pytest.Config) -> list[list[Device]]:
"""
Returns a list where each element is a list with two or more devices
belonging to the same API and platform, where APIs, platforms, and devices
are filtered by the test configuration.
"""
device_sets = _get_device_sets(config, unique_devices_only_override=False)
return [device_set for device_set in device_sets if len(device_set) > 1]
[docs]
@pytest.fixture()
def api(request: pytest.FixtureRequest) -> API:
"""Yields the elements of the return value of :py:func:`~pytest_grunnur.get_apis`."""
# Set in `pytest_generate_tests()`
return cast(API, request.param)
[docs]
@pytest.fixture()
def device(request: pytest.FixtureRequest) -> Device:
"""Yields the elements of the return value of :py:func:`~pytest_grunnur.get_devices`."""
# Set in `pytest_generate_tests()`
return cast(Device, request.param)
[docs]
@pytest.fixture()
def context(device: Device) -> Context:
"""
A single-device context for each device yielded by
:py:func:`~pytest_grunnur.plugin.device`.
"""
return Context.from_devices([device])
[docs]
@pytest.fixture()
def some_device(request: pytest.FixtureRequest) -> Device:
"""Yields one element of the return value of :py:func:`~pytest_grunnur.get_devices`."""
# Set in `pytest_generate_tests()`
return cast(Device, request.param)
[docs]
@pytest.fixture()
def some_context(some_device: Device) -> Context:
"""
A single-device context initialized with the return value of
:py:func:`~pytest_grunnur.plugin.some_device`.
"""
return Context.from_devices([some_device])
[docs]
@pytest.fixture()
def multi_device_set(request: pytest.FixtureRequest) -> list[Device]:
"""
Yields the elements of the return value of
:py:func:`~pytest_grunnur.get_multi_device_sets`.
"""
# Set in `pytest_generate_tests()`
return cast(list[Device], request.param)
[docs]
@pytest.fixture()
def multi_device_context(multi_device_set: list[Device]) -> Context:
"""
A multi-device context for each device set yielded by
:py:func:`~pytest_grunnur.plugin.multi_device_set`.
"""
return Context.from_devices(multi_device_set)
[docs]
def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
"""
Seeds the parameters for the fixtures provided by this plugin
(see the fixture list for details).
"""
apis = get_apis(metafunc.config)
platforms = get_platforms(metafunc.config)
devices = get_devices(metafunc.config)
fixtures: list[tuple[str, list[Any]]] = [
("api", apis),
("platform", platforms),
("device", devices),
]
for name, vals in fixtures:
if name in metafunc.fixturenames:
metafunc.parametrize(
name,
vals,
ids=["no_" + name] if len(vals) == 0 else lambda obj: cast(str, obj.shortcut),
indirect=True,
)
if "some_device" in metafunc.fixturenames:
metafunc.parametrize(
"some_device",
devices if len(devices) == 0 else [devices[0]],
ids=["no_device"] if len(devices) == 0 else lambda device: cast(str, device.shortcut),
indirect=True,
)
if "multi_device_set" in metafunc.fixturenames:
device_sets = get_multi_device_sets(metafunc.config)
ids = ["+".join(device.shortcut for device in device_set) for device_set in device_sets]
metafunc.parametrize(
"multi_device_set",
device_sets,
ids=["no_multi_device"] if len(device_sets) == 0 else ids,
indirect=True,
)