from typing import Iterator, Tuple, List
import os
import json
import urllib
import urllib.parse
import logging
import warnings
import shutil
import glob
import git
import yaml
from .build import BuildManager
from ..exceptions import NameInUseError, BadManifestFile
from ..compiler import Compiler
from ..core import TestSuite, Bug, Language, BuildInstructions, \
CoverageInstructions, Tool, Source, SourceContents, RemoteSource, \
LocalSource
# Module-level logger, named after this module, used by SourceManager below.
logger = logging.getLogger(__name__) # type: logging.Logger
# NOTE(review): this sets the level of this module's logger to DEBUG as a side
# effect of importing the module -- confirm that this is intentional.
logger.setLevel(logging.DEBUG)
# Names exported by `from <this module> import *`.
__all__ = ['SourceManager']
class SourceManager(object):
    """
    Keeps track of the sources that are registered with a BugZoo
    installation. A source is either a local directory or a local clone of a
    remote Git repository; the manifest files found inside a source describe
    the bugs, blueprints, and tools that it provides. Registered sources are
    persisted to a YAML registry file inside the installation directory.

    TODO: we *could* cache the contents of all of the sources to disk, avoiding
        the need to scan for them at startup. Although that might be cool, it
        seems like overengineering and may create compatibility headaches in
        the future.
    """
    def __init__(self, installation: 'BugZoo') -> None:
        """
        Constructs a source manager for a given installation and immediately
        loads all sources listed in its registry file (if there is one).

        Parameters:
            installation: the installation that owns this manager; its `path`
                attribute determines where sources and the registry live.
        """
        self.__installation = installation
        self.__path = os.path.join(installation.path, 'sources')
        # TODO
        self.__registry_fn = os.path.join(self.__path, 'registry.yml')
        # both dictionaries are indexed by the name of the source
        self.__sources = {}
        self.__contents = {}
        self.refresh()

    def __iter__(self) -> Iterator[Source]:
        """
        Returns an iterator over the sources registered with this server.
        """
        return iter(self.__sources.values())

    def __getitem__(self, name: str) -> Source:
        """
        Attempts to fetch the description of a given source.

        Parameters:
            name: the name of the source.

        Returns:
            a description of the source.

        Raises:
            KeyError: if no source is found with the given name.
        """
        return self.__sources[name]

    def __delitem__(self, name: str) -> None:
        """
        See `remove`.

        Raises:
            KeyError: if no source is found with the given name.
        """
        return self.remove(self[name])

    def refresh(self) -> None:
        """
        Reloads all sources that are registered with this server.

        All currently loaded sources are unloaded, and the sources listed in
        the registry file are then loaded from scratch. If there is no
        registry file, the manager is simply left empty.
        """
        logger.info('refreshing sources')
        # iterate over a snapshot: unload() mutates the underlying dictionary
        for source in list(self):
            self.unload(source)

        if not os.path.exists(self.__registry_fn):
            return

        # TODO add version
        with open(self.__registry_fn, 'r') as f:
            registry = yaml.safe_load(f)

        # an empty registry file is parsed as None rather than as an empty list
        if registry is None:
            registry = []
        assert isinstance(registry, list)

        for source_description in registry:
            source = Source.from_dict(source_description)
            self.load(source)
        logger.info('refreshed sources')

    def update(self) -> None:
        """
        Ensures that all remote sources are up-to-date.

        Pulls the local clone of every remote source; any source whose
        version changed as a result is reloaded. The registry is saved to
        disk afterwards.
        """
        # iterate over a snapshot: load() unregisters and re-registers the
        # updated source, which mutates the dictionary that backs __iter__
        for source_old in list(self):
            if isinstance(source_old, RemoteSource):
                repo = git.Repo(source_old.location)
                origin = repo.remotes.origin
                origin.pull()

                sha = repo.head.object.hexsha
                version = repo.git.rev_parse(sha, short=8)

                if version != source_old.version:
                    source_new = RemoteSource(source_old.name,
                                              source_old.location,
                                              source_old.url,
                                              version)
                    logger.info("updated source: %s [%s -> %s]", source_old.name,
                                source_old.version,
                                source_new.version)
                    self.load(source_new)
                else:
                    logger.debug("no updates for source: %s", source_old.name)

        # write to disk
        # TODO local directory may be corrupted if program terminates between
        #   repo being updated and registry being saved; could add a "fix"
        #   command to recalculate versions for remote sources
        self.save()

    def save(self) -> None:
        """
        Saves the contents of the source manager to disk.
        """
        logger.info('saving registry to: %s', self.__registry_fn)
        d = [s.to_dict() for s in self]
        os.makedirs(self.__path, exist_ok=True)
        with open(self.__registry_fn, 'w') as f:
            yaml.dump(d, f, indent=2, default_flow_style=False)
        logger.info('saved registry to: %s', self.__registry_fn)

    def unload(self, source: Source) -> None:
        """
        Unloads a registered source, causing all of its associated bugs, tools,
        and blueprints to also be unloaded. If the given source is not loaded,
        this function will do nothing.

        A bug, blueprint, or tool that can no longer be found in the
        installation is skipped (with a warning) so that the remaining
        resources of the source are still unloaded.
        """
        logger.info('unloading source: %s', source.name)
        try:
            contents = self.contents(source)
        except KeyError:
            contents = None

        if contents is not None:
            del self.__contents[source.name]
            self.__sources.pop(source.name, None)

            installation = self.__installation
            for kind, manager, names in (
                    ('bug', installation.bugs, contents.bugs),
                    ('blueprint', installation.build, contents.blueprints),
                    ('tool', installation.tools, contents.tools)):
                for name in names:
                    # handle each resource on its own: a single missing entry
                    # must not leave the rest of the source registered
                    try:
                        manager.remove(manager[name])
                    except KeyError:
                        logger.warning("unable to unload %s from source %s: %s",
                                       kind, source.name, name)
        logger.info('unloaded source: %s', source.name)

    def __parse_blueprint(self, source: Source, fn: str, d: dict) -> BuildInstructions:
        """
        Builds a blueprint from its description in a manifest file.

        Parameters:
            source: the source that provides the manifest file.
            fn: the path to the manifest file; its directory is used as the
                root of the blueprint.
            d: the description of the blueprint.

        Raises:
            KeyError: if the description has no 'tag' property.
        """
        return BuildInstructions(root=os.path.dirname(fn),
                                 tag=d['tag'],
                                 context=d.get('context', '.'),
                                 filename=d.get('file', 'Dockerfile'),
                                 arguments=d.get('arguments', {}),
                                 source=source.name,
                                 build_stage=d.get('build-stage', None),
                                 depends_on=d.get('depends-on', None))

    def __parse_bug(self, source: Source, fn: str, d: dict) -> Bug:
        """
        Builds a bug from its description in a manifest file. The optional
        'dataset' and 'program' properties default to None, and the bug is
        attributed to the given source. The given description is not modified.
        """
        d_ = d.copy()
        d_['dataset'] = d.get('dataset', None)
        d_['program'] = d.get('program', None)
        d_['source'] = source.name
        return Bug.from_dict(d_)

    def __parse_tool(self, source: Source, fn: str, d: dict) -> Tool:
        """
        Builds a tool from its description in a manifest file.

        Raises:
            KeyError: if the description has no 'name' or 'image' property.
        """
        return Tool(d['name'],
                    d['image'],
                    d.get('environment', {}),
                    source.name)

    def __parse_file(self,
                     source: Source,
                     fn: str,
                     bugs: List[Bug],
                     blueprints: List[BuildInstructions],
                     tools: List[Tool]
                     ) -> None:
        """
        Parses a single manifest file and appends the bugs, blueprints, and
        tools that it describes to the given lists. A description that lacks
        a required property is logged and skipped.

        Parameters:
            source: the source that provides the manifest file.
            fn: the path to the manifest file.
            bugs: the list to which parsed bugs are appended.
            blueprints: the list to which parsed blueprints are appended.
            tools: the list to which parsed tools are appended.
        """
        with open(fn, 'r') as f:
            yml = yaml.safe_load(f)

        # an empty manifest file is parsed as None rather than as a mapping
        if yml is None:
            yml = {}

        # TODO check version
        if 'version' not in yml:
            logger.warning("no version specified in manifest file: %s", fn)

        # `or []` covers a section key that is present but has no value
        for description in yml.get('bugs') or []:
            logger.debug("parsing bug: %s", json.dumps(description))
            try:
                bug = self.__parse_bug(source, fn, description)
                logger.debug("parsed bug: %s", bug.name)
                bugs.append(bug)
            except KeyError as e:
                logger.exception("missing property in bug description: %s",
                                 str(e))

        for description in yml.get('blueprints') or []:
            logger.debug("parsing blueprint: %s", json.dumps(description))
            try:
                blueprint = self.__parse_blueprint(source, fn, description)
                logger.debug("parsed blueprint for image: %s",
                             blueprint.name)
                blueprints.append(blueprint)
            except KeyError as e:
                logger.exception("missing property in blueprint description: %s",
                                 str(e))

        for description in yml.get('tools') or []:
            logger.debug("parsing tool: %s", json.dumps(description))
            try:
                tool = self.__parse_tool(source, fn, description)
                logger.debug("parsed tool: %s", tool.name)
                tools.append(tool)
            except KeyError as e:
                logger.exception("missing property in tool description: %s",
                                 str(e))

    def load(self, source: Source) -> None:
        """
        Attempts to load all resources (i.e., bugs, tools, and blueprints)
        provided by a given source. If the given source has already been
        loaded, then the resources for that source are unloaded and
        reloaded.
        """
        logger.info('loading source %s at %s', source.name, source.location)
        if source.name in self.__sources:
            self.unload(source)

        bugs = []
        blueprints = []
        tools = []

        # find and parse all bugzoo files
        glob_pattern = '{}/**/*.bugzoo.y*ml'.format(source.location)
        for fn in glob.iglob(glob_pattern, recursive=True):
            # the pattern also matches names such as "x.bugzoo.yxml"
            if fn.endswith(('.yml', '.yaml')):
                logger.debug('found manifest file: %s', fn)
                self.__parse_file(source, fn, bugs, blueprints, tools)
                logger.debug('parsed manifest file: %s', fn)

        # register contents
        for bug in bugs:
            self.__installation.bugs.add(bug)
        for blueprint in blueprints:
            self.__installation.build.add(blueprint)
        for tool in tools:
            self.__installation.tools.add(tool)

        # record contents of source
        contents = SourceContents([b.name for b in blueprints],
                                  [b.name for b in bugs],
                                  [t.name for t in tools])
        self.__sources[source.name] = source
        self.__contents[source.name] = contents
        logger.info("loaded source: %s", source.name)

    def contents(self, source: Source) -> SourceContents:
        """
        Returns a summary of the bugs, tools, and blueprints provided by a
        given source.

        Raises:
            KeyError: if the given source is not loaded.
        """
        return self.__contents[source.name]

    def add(self, name: str, path_or_url: str) -> Source:
        """
        Attempts to register a source provided by a given URL or local path
        under a given name. A source given by an http(s) URL is cloned into
        the installation directory; any directory already present at the
        clone destination is deleted first.

        Parameters:
            name: the name under which the source should be registered.
            path_or_url: an http(s) URL of a Git repository, or the path to a
                local directory.

        Returns:
            a description of the registered source.

        Raises:
            NameInUseError: if an existing source is already registered under
                the given name.
            IOError: if no directory exists at the given path.
            IOError: if downloading the remote source failed. (FIXME)
        """
        logger.info("adding source: %s -> %s", name, path_or_url)
        if name in self.__sources:
            logger.info("name already used by existing source: %s", name)
            raise NameInUseError(name)

        try:
            scheme = urllib.parse.urlparse(path_or_url).scheme
            is_url = scheme in ['http', 'https']
        except ValueError:
            is_url = False

        if is_url:
            logger.debug("source determined to be remote: %s", path_or_url)
            url = path_or_url

            # convert url to a local path; the scheme is dropped regardless of
            # whether it is http or https
            path = url.split('://', 1)[-1]
            path = path.replace('/', '_')
            path = path.replace('.', '_')
            path = os.path.join(self.__path, path)

            # download from remote to local
            shutil.rmtree(path, ignore_errors=True)
            try:
                # TODO shallow clone
                logger.debug("cloning repository %s to %s", url, path)
                repo = git.Repo.clone_from(url, path)
                logger.debug("cloned repository %s to %s", url, path)

                sha = repo.head.object.hexsha
                version = repo.git.rev_parse(sha, short=8)
            except Exception as err:  # TODO determine error type
                # do not leave a partial clone behind
                shutil.rmtree(path, ignore_errors=True)
                logger.error("failed to download remote source to local: %s -> %s", url, path)
                raise IOError("failed to download remote source to local installation: '{}' -> '{}'".format(url, path)) from err

            source = RemoteSource(name, path, url, version)
        else:
            logger.debug("source determined to be local: %s", path_or_url)
            path = os.path.abspath(path_or_url)
            if not os.path.isdir(path):
                raise IOError("no directory found at path: {}".format(path))
            source = LocalSource(name, path)

        self.load(source)
        self.save()
        logger.info('added source: %s', name)
        return source

    def remove(self, source: Source) -> None:
        """
        Unregisters a given source with this server. If the given source is a
        remote source, then its local copy will be removed from disk. The
        registry is saved to disk afterwards.

        Unloading a source that is not registered does nothing (see
        `unload`), and so no error is raised in that case.
        """
        self.unload(source)
        if isinstance(source, RemoteSource):
            shutil.rmtree(source.location, ignore_errors=True)
        self.save()