# Copyright 2013-2024 Lawrence Livermore National Security, LLC and other
# Spack Project Developers. See the top-level COPYRIGHT file for details.
#
# SPDX-License-Identifier: (Apache-2.0 OR MIT)
"""Classes and functions to manage providers of virtual dependencies"""
from typing import Dict, List, Optional, Set
import spack.error
import spack.spec
import spack.util.spack_json as sjson
class _IndexBase:
#: This is a dict of dicts used for finding providers of particular
#: virtual dependencies. The dict of dicts looks like:
#:
#: { vpkg name :
#: { full vpkg spec : set(packages providing spec) } }
#:
#: Callers can use this to first find which packages provide a vpkg,
#: then find a matching full spec. e.g., in this scenario:
#:
#: { 'mpi' :
#: { mpi@:1.1 : set([mpich]),
#: mpi@:2.3 : set([mpich2@1.9:]) } }
#:
#: Calling providers_for(spec) will find specs that provide a
#: matching implementation of MPI. Derived class need to construct
#: this attribute according to the semantics above.
providers: Dict[str, Dict[str, Set[str]]]
def providers_for(self, virtual_spec):
"""Return a list of specs of all packages that provide virtual
packages with the supplied spec.
Args:
virtual_spec: virtual spec to be provided
"""
result = set()
# Allow string names to be passed as input, as well as specs
if isinstance(virtual_spec, str):
virtual_spec = spack.spec.Spec(virtual_spec)
# Add all the providers that satisfy the vpkg spec.
if virtual_spec.name in self.providers:
for p_spec, spec_set in self.providers[virtual_spec.name].items():
if p_spec.intersects(virtual_spec, deps=False):
result.update(spec_set)
# Return providers in order. Defensively copy.
return sorted(s.copy() for s in result)
def __contains__(self, name):
return name in self.providers
def __eq__(self, other):
return self.providers == other.providers
def _transform(self, transform_fun, out_mapping_type=dict):
"""Transform this provider index dictionary and return it.
Args:
transform_fun: transform_fun takes a (vpkg, pset) mapping and runs
it on each pair in nested dicts.
out_mapping_type: type to be used internally on the
transformed (vpkg, pset)
Returns:
Transformed mapping
"""
return _transform(self.providers, transform_fun, out_mapping_type)
def __str__(self):
return str(self.providers)
def __repr__(self):
return repr(self.providers)
class ProviderIndex(_IndexBase):
    """Provider index backed by a single ``providers`` mapping."""

    def __init__(
        self,
        repository: "spack.repo.RepoType",
        specs: Optional[List["spack.spec.Spec"]] = None,
        restrict: bool = False,
    ):
        """Provider index based on a single mapping of providers.

        Args:
            repository: repository queried to tell virtual names apart and to
                look up the package class of each provider
            specs: if provided, will call update on each
                single spec to initialize this provider index.
            restrict: "restricts" values to the verbatim input specs; do not
                pre-apply package's constraints.

        TODO: rename this.  It is intended to keep things as broad
        TODO: as possible without overly restricting results, so it is
        TODO: not the best name.
        """
        self.repository = repository
        self.restrict = restrict
        self.providers = {}

        for spec in specs or []:
            if not isinstance(spec, spack.spec.Spec):
                spec = spack.spec.Spec(spec)

            # Virtual names are skipped here; update() rejects them
            if self.repository.is_virtual_safe(spec.name):
                continue

            self.update(spec)

    def update(self, spec):
        """Update the provider index with additional virtual specs.

        Args:
            spec: spec potentially providing additional virtual specs
        """
        if not isinstance(spec, spack.spec.Spec):
            spec = spack.spec.Spec(spec)

        if not spec.name:
            # Empty specs do not have a package
            return

        msg = f"cannot update an index passing the virtual spec '{spec.name}'"
        assert not self.repository.is_virtual_safe(spec.name), msg

        pkg_provided = self.repository.get_pkg_class(spec.name).provided
        for provider_spec_readonly, provided_specs in pkg_provided.items():
            # The check below should ignore compiler flags, so work on a copy
            # of the package's spec that carries the flags of the input spec.
            # This does not depend on the provided spec: compute it once.
            provider_spec = provider_spec_readonly.copy()
            provider_spec.compiler_flags = spec.compiler_flags.copy()

            if not spec.intersects(provider_spec, deps=False):
                continue

            for provided_spec in provided_specs:
                provider_map = self.providers.setdefault(provided_spec.name, {})
                provider_set = provider_map.setdefault(provided_spec, set())

                if self.restrict:
                    # If this package existed in the index before,
                    # need to take the old versions out, as they're
                    # now more constrained.
                    stale = {s for s in provider_set if s.name == spec.name}
                    provider_set.difference_update(stale)

                    # Now add the new version.
                    provider_set.add(spec)
                else:
                    # Before putting the spec in the map, constrain
                    # it so that it provides what was asked for.
                    constrained = spec.copy()
                    constrained.constrain(provider_spec)
                    provider_set.add(constrained)

    def to_json(self, stream=None):
        """Dump a JSON representation of this object.

        Args:
            stream: stream where to dump
        """
        # Each virtual name maps to a list of [vpkg node dict, [provider node dicts]]
        provider_list = self._transform(
            lambda vpkg, pset: [vpkg.to_node_dict(), [p.to_node_dict() for p in pset]], list
        )

        sjson.dump({"provider_index": {"providers": provider_list}}, stream)

    def merge(self, other):
        """Merge another provider index into this one.

        Args:
            other (ProviderIndex): provider index to be merged
        """
        other = other.copy()  # defensive copy.

        for pkg, opdict in other.providers.items():
            if pkg not in self.providers:
                self.providers[pkg] = opdict
                continue

            spdict = self.providers[pkg]
            for provided_spec, other_set in opdict.items():
                if provided_spec not in spdict:
                    spdict[provided_spec] = other_set
                    continue

                # union() builds a new set instead of growing the existing one
                spdict[provided_spec] = spdict[provided_spec].union(other_set)

    def remove_provider(self, pkg_name):
        """Remove a provider from the ProviderIndex.

        Args:
            pkg_name: providers whose ``fullname`` equals this value are
                removed; entries left empty afterwards are pruned
        """
        # Deletions are deferred: the dicts cannot change size while iterated
        empty_pkg_dict = []
        for pkg, pkg_dict in self.providers.items():
            empty_pset = []
            for provided, pset in pkg_dict.items():
                same_name = {p for p in pset if p.fullname == pkg_name}
                pset.difference_update(same_name)

                if not pset:
                    empty_pset.append(provided)

            for provided in empty_pset:
                del pkg_dict[provided]

            if not pkg_dict:
                empty_pkg_dict.append(pkg)

        for pkg in empty_pkg_dict:
            del self.providers[pkg]

    def copy(self):
        """Return a copy of this index with the same repository and
        ``restrict`` setting; every provider spec is copied."""
        # Carry over `restrict`, so that updating the copy behaves like
        # updating the index it was made from
        clone = ProviderIndex(repository=self.repository, restrict=self.restrict)
        clone.providers = self._transform(lambda vpkg, pset: (vpkg, {p.copy() for p in pset}))
        return clone

    @staticmethod
    def from_json(stream, repository):
        """Construct a provider index from its JSON representation.

        Args:
            stream: stream where to read from the JSON data
            repository: repository the new index is attached to

        Raises:
            ProviderIndexError: if the decoded data is not a dict, or has no
                "provider_index" key
        """
        data = sjson.load(stream)

        if not isinstance(data, dict):
            raise ProviderIndexError("JSON ProviderIndex data was not a dict.")

        if "provider_index" not in data:
            raise ProviderIndexError("JSON ProviderIndex does not start with 'provider_index'")

        index = ProviderIndex(repository=repository)
        providers = data["provider_index"]["providers"]
        index.providers = _transform(
            providers,
            lambda vpkg, plist: (
                spack.spec.SpecfileV4.from_node_dict(vpkg),
                {spack.spec.SpecfileV4.from_node_dict(p) for p in plist},
            ),
        )
        return index
def _transform(providers, transform_fun, out_mapping_type=dict):
"""Syntactic sugar for transforming a providers dict.
Args:
providers: provider dictionary
transform_fun: transform_fun takes a (vpkg, pset) mapping and runs
it on each pair in nested dicts.
out_mapping_type: type to be used internally on the
transformed (vpkg, pset)
Returns:
Transformed mapping
"""
def mapiter(mappings):
if isinstance(mappings, dict):
return mappings.items()
else:
return iter(mappings)
return dict(
(name, out_mapping_type([transform_fun(vpkg, pset) for vpkg, pset in mapiter(mappings)]))
for name, mappings in providers.items()
)
class ProviderIndexError(spack.error.SpackError):
    """Raised when there is a problem with a ProviderIndex.

    In this module it is raised by ``ProviderIndex.from_json`` when the
    decoded data is not a dict or lacks the "provider_index" key.
    """