from functools import lru_cache
from typing import Iterable, TypeVar, Optional, List, Iterator, Any, Tuple, cast
import pytest
from grunnur import API, all_api_ids, Platform, PlatformFilter, Device, DeviceFilter, Context
def pytest_addoption(parser: pytest.Parser) -> None:
    """
    Adds the following command-line options:

    * ``--api``: select a specific API to test
      (one of the shortcuts returned by :py:func:`grunnur.all_api_ids`).
    * ``--platform-include-mask``: run tests only on platforms whose names match the mask.
    * ``--platform-exclude-mask``: exclude platforms whose names match the mask from the tests.
    * ``--device-include-mask``: run tests only on devices whose names match the mask.
    * ``--device-exclude-mask``: exclude devices whose names match the mask from the tests.
    * ``--include-duplicate-devices``: if there are devices with the same name in the platform,
      run tests on all of them.
    * ``--include-pure-parallel-devices``: include pure parallel devices in the tests
      (that is, those not supporting synchronization within a block/work group).

    The mask options use ``action="append"``, so each of them may be given several times;
    the values are collected into a list (empty by default).

    :param parser: the pytest option parser the options are registered with.
    """
    # The shortcuts serve both as the allowed values and as part of the help text.
    api_shortcuts = [api_id.shortcut for api_id in all_api_ids()]

    parser.addoption(
        "--api",
        action="store",
        help="GPGPU API: " + "/".join(api_shortcuts) + " (or all available if not given)",
        default=None,
        choices=api_shortcuts,
    )

    parser.addoption(
        "--platform-include-mask",
        action="append",
        help="Run tests on matching platforms only",
        default=[],
    )

    parser.addoption(
        "--platform-exclude-mask",
        action="append",
        # Previously duplicated the include-mask help text, which described the opposite effect.
        help="Exclude matching platforms from the tests",
        default=[],
    )

    parser.addoption(
        "--device-include-mask",
        action="append",
        help="Run tests on matching devices only",
        default=[],
    )

    parser.addoption(
        "--device-exclude-mask",
        action="append",
        # Previously duplicated the include-mask help text, which described the opposite effect.
        help="Exclude matching devices from the tests",
        default=[],
    )

    parser.addoption(
        "--include-duplicate-devices",
        action="store_true",
        help="Run tests on all available devices and not only on uniquely named ones",
        default=False,
    )

    parser.addoption(
        "--include-pure-parallel-devices",
        action="store_true",
        help="Include pure parallel devices (not supporting synchronization within a work group)",
        default=False,
    )
@lru_cache()
def get_apis(config: pytest.Config) -> List[API]:
    """
    Returns the list of APIs selected by the ``--api`` option of the test configuration.

    The value of ``config.option.api`` is passed to ``API.all_by_shortcut()`` as is
    (it is ``None`` when the option was not given).
    The result is memoized per ``config`` object.
    """
    requested_shortcut = config.option.api
    return API.all_by_shortcut(requested_shortcut)
_T = TypeVar("_T")


def concatenate(lists: Iterable[List[_T]]) -> List[_T]:
    """
    Flattens an iterable of lists into a single new list, preserving the order of elements.

    An empty iterable produces an empty list. The input lists are not modified.
    """
    # A single pass over all the elements; ``sum(lists, [])`` would build a new
    # intermediate list at every step, which is quadratic in the total length.
    return [item for sublist in lists for item in sublist]
@lru_cache()
def get_device_sets(
    config: pytest.Config, unique_devices_only_override: Optional[bool] = None
) -> List[List[Device]]:
    """
    Returns one list of devices for every platform returned by ``get_platforms(config)``,
    each obtained from ``Device.all_filtered()`` with a filter built from the
    ``--device-include-mask``, ``--device-exclude-mask`` and
    ``--include-pure-parallel-devices`` options.

    If ``unique_devices_only_override`` is not ``None``, it is used as the filter's
    ``unique_only`` value; otherwise ``unique_only`` is the negation of
    the ``--include-duplicate-devices`` option.

    The result is memoized per combination of arguments.
    """
    unique_only = (
        not config.option.include_duplicate_devices
        if unique_devices_only_override is None
        else unique_devices_only_override
    )

    sets_per_platform: List[List[Device]] = []
    for platform in get_platforms(config):
        device_filter = DeviceFilter(
            include_masks=config.option.device_include_mask,
            exclude_masks=config.option.device_exclude_mask,
            unique_only=unique_only,
            exclude_pure_parallel=not config.option.include_pure_parallel_devices,
        )
        sets_per_platform.append(Device.all_filtered(platform, device_filter))
    return sets_per_platform
@lru_cache()
def get_devices(config: pytest.Config) -> List[Device]:
    """
    Returns a flat list of devices filtered by the test configuration:
    the per-platform lists produced by ``get_device_sets(config)``, joined in order.
    """
    per_platform_sets = get_device_sets(config)
    return concatenate(per_platform_sets)
@lru_cache()
def get_multi_device_sets(config: pytest.Config) -> List[List[Device]]:
    """
    Returns a list where each element is a list with two or more devices
    taken from a single platform, with platforms and devices
    filtered by the test configuration.

    The device sets are requested with ``unique_devices_only_override=False``,
    that is, regardless of the ``--include-duplicate-devices`` option.
    """
    multi_sets: List[List[Device]] = []
    for candidate in get_device_sets(config, unique_devices_only_override=False):
        # Sets with a single device (or none) cannot form a multi-device context.
        if len(candidate) > 1:
            multi_sets.append(candidate)
    return multi_sets
@pytest.fixture
def api(request: pytest.FixtureRequest) -> Iterator[API]:
    """
    Yields the parameter this fixture was parametrized with;
    :py:func:`pytest_generate_tests` seeds it with the elements of
    the return value of :py:func:`~pytest_grunnur.get_apis`.
    """
    current_api = request.param
    yield current_api
@pytest.fixture
def device(request: pytest.FixtureRequest) -> Iterator[Device]:
    """
    Yields the parameter this fixture was parametrized with;
    :py:func:`pytest_generate_tests` seeds it with the elements of
    the return value of :py:func:`~pytest_grunnur.get_devices`.
    """
    current_device = request.param
    yield current_device
@pytest.fixture
def context(device: Device) -> Iterator[Context]:
    """
    Yields a context created from a single device:
    the one provided by the :py:func:`~pytest_grunnur.plugin.device` fixture.
    """
    single_device_list = [device]
    yield Context.from_devices(single_device_list)
@pytest.fixture
def some_device(request: pytest.FixtureRequest) -> Iterator[Device]:
    """
    Yields the parameter this fixture was parametrized with;
    :py:func:`pytest_generate_tests` seeds it with the first element of
    the return value of :py:func:`~pytest_grunnur.get_devices`.
    """
    chosen_device = request.param
    yield chosen_device
@pytest.fixture
def some_context(some_device: Device) -> Iterator[Context]:
    """
    Yields a context created from a single device:
    the one provided by the :py:func:`~pytest_grunnur.plugin.some_device` fixture.
    """
    single_device_list = [some_device]
    yield Context.from_devices(single_device_list)
@pytest.fixture
def multi_device_set(request: pytest.FixtureRequest) -> Iterator[List[Device]]:
    """
    Yields the parameter this fixture was parametrized with;
    :py:func:`pytest_generate_tests` seeds it with the elements of
    the return value of :py:func:`~pytest_grunnur.get_multi_device_sets`.
    """
    current_device_set = request.param
    yield current_device_set
@pytest.fixture
def multi_device_context(multi_device_set: List[Device]) -> Iterator[Context]:
    """
    Yields a context created from all the devices in the set provided by
    the :py:func:`~pytest_grunnur.plugin.multi_device_set` fixture.
    """
    devices_for_context = multi_device_set
    yield Context.from_devices(devices_for_context)
def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
    """
    Seeds the parameters for the fixtures provided by this plugin
    (see the fixture list for details).

    For every plugin fixture requested by the test, an indirect parametrization is added:

    * ``api``, ``platform``, ``device``: all the filtered APIs, platforms, or devices,
      with their ``shortcut`` attributes as test IDs;
    * ``some_device``: only the first of the filtered devices;
    * ``multi_device_set``: all the multi-device sets, with the device shortcuts
      joined by ``+`` as test IDs.

    When there are no values for a fixture, a single placeholder ID is supplied
    (``no_<fixture name>``, ``no_device``, or ``no_multi_device`` respectively).
    """

    def shortcut_id(obj: Any) -> str:
        # Test ID callback: identifies a parameter by its ``shortcut`` attribute.
        return cast(str, obj.shortcut)

    config = metafunc.config

    api_list = get_apis(config)
    platform_list = get_platforms(config)
    device_list = get_devices(config)

    values_by_fixture: List[Tuple[str, List[Any]]] = [
        ("api", api_list),
        ("platform", platform_list),
        ("device", device_list),
    ]

    for fixture_name, values in values_by_fixture:
        if fixture_name not in metafunc.fixturenames:
            continue
        metafunc.parametrize(
            fixture_name,
            values,
            ids=shortcut_id if values else ["no_" + fixture_name],
            indirect=True,
        )

    if "some_device" in metafunc.fixturenames:
        # At most one device: the first one, if there are any.
        metafunc.parametrize(
            "some_device",
            device_list[:1],
            ids=shortcut_id if device_list else ["no_device"],
            indirect=True,
        )

    if "multi_device_set" in metafunc.fixturenames:
        multi_sets = get_multi_device_sets(config)
        set_ids = []
        for multi_set in multi_sets:
            set_ids.append("+".join(member.shortcut for member in multi_set))
        metafunc.parametrize(
            "multi_device_set",
            multi_sets,
            ids=set_ids if multi_sets else ["no_multi_device"],
            indirect=True,
        )