# Copyright 2019 Eli Lilly and Company
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABCMeta, abstractmethod
from enum import Enum
import functools
from pathlib import Path
from typing import Optional, Sequence
from urllib.request import BaseHandler, Request, build_opener, install_opener
from pkg_resources import iter_entry_points
from pytest_wdl.plugins import PluginError, PluginFactory
from pytest_wdl.utils import LOG, verify_digests
try:
from tqdm import tqdm as progress
except ImportError: # pragma: no-cover
LOG.debug(
"tqdm is not installed; progress bar will not be displayed when "
"downloading files"
)
progress = None
class Method(Enum):
    """
    Hooks that a :class:`UrlHandler` subclass may implement.

    Each member couples two strings:

    * ``src_attr`` - name of the handler attribute holding the implementation.
    * ``dest_pattern`` - format string that, filled in with a URL scheme,
      yields the attribute name the implementation is aliased to.
    """
    OPEN = ("urlopen", "{}_open")
    REQUEST = ("request", "{}_request")
    RESPONSE = ("response", "{}_response")

    def __init__(self, source_name, target_template):
        # Enum unpacks each member's tuple value into these positional args.
        self.src_attr, self.dest_pattern = source_name, target_template
class Response(metaclass=ABCMeta):
    """
    Interface for a response whose payload can be saved to a local file.
    """

    @abstractmethod
    def download_file(
        self,
        destination: Path,
        show_progress: bool = False,
        digests: Optional[dict] = None
    ):
        """
        Write the response payload to `destination`.

        Args:
            destination: Path of the file to write.
            show_progress: If True, display a progress bar while downloading.
            digests: Optional mapping of hash name to expected digest; when
                given, the downloaded file is validated against it.

        Raises:
            DigestsNotEqualError
        """
class BaseResponse(Response, metaclass=ABCMeta):
    """
    :class:`Response` that downloads its body by reading it in fixed-size
    blocks. Subclasses implement :meth:`get_content_length` and :meth:`read`.
    """

    @abstractmethod
    def get_content_length(self) -> Optional[int]:
        """
        Returns:
            The expected size of the body in bytes, or None if it is unknown.
        """
        pass

    @abstractmethod
    def read(self, block_size: int):
        """
        Read the next chunk of the body.

        Args:
            block_size: Number of bytes to request.

        Returns:
            The bytes read. An empty (falsy) result marks the end of the body.
        """
        pass

    def download_file(
        self,
        destination: Path,
        show_progress: bool = False,
        digests: Optional[dict] = None
    ):
        """
        Stream the body to `destination`, optionally showing a progress bar
        and verifying digests.

        Args:
            destination: Destination path.
            show_progress: Whether to show a progress bar (only shown when
                tqdm is installed).
            digests: Optional dict mapping hash names to digests. These are
                used to validate the downloaded file.

        Raises:
            AssertionError: if the content length is known and the number of
                bytes written does not match it.
            DigestsNotEqualError: if `digests` does not match the file.
        """
        total_size = self.get_content_length()
        block_size = 16 * 1024
        # Don't request more than the whole body when it fits in one block.
        if total_size and total_size < block_size:
            block_size = total_size

        progress_bar = None
        if show_progress and progress:
            progress_bar = progress(
                total=total_size,
                unit="b",
                unit_scale=True,
                unit_divisor=1024,
                desc=f"Localizing {destination.name}"
            )

        downloaded_size = 0
        try:
            with open(destination, "wb") as out:
                while True:
                    buf = self.read(block_size)
                    if not buf:
                        break
                    out.write(buf)
                    downloaded_size += len(buf)
                    if progress_bar is not None:
                        # Advance by the bytes actually received; the last
                        # block is usually shorter than block_size.
                        progress_bar.update(len(buf))
        finally:
            # Close the bar even if opening, reading or writing fails.
            if progress_bar is not None:
                progress_bar.close()

        # Only validate the size when the response reported one; without a
        # content length there is nothing to compare against.
        if total_size is not None and downloaded_size != total_size:
            raise AssertionError(
                f"Size of downloaded file {destination} does not match expected size "
                f"{total_size}"
            )

        if digests:
            verify_digests(destination, digests)
class ResponseWrapper(BaseResponse):
    """
    Adapts an object that provides ``getheader(name)`` and ``read(size)`` to
    the :class:`BaseResponse` interface.
    """

    def __init__(self, rsp):
        self.rsp = rsp

    def get_content_length(self) -> Optional[int]:
        """
        Returns:
            The content-length header of the wrapped response as an int, or
            None when the header is missing or empty.
        """
        header_value = self.rsp.getheader("content-length")
        return int(header_value) if header_value else None

    def read(self, block_size: int) -> bytes:
        """Delegate to the wrapped response's ``read(block_size)``."""
        return self.rsp.read(block_size)
class UrlHandler(BaseHandler, metaclass=ABCMeta):
    """
    Base class for urllib handlers that add support for a URL scheme.

    Subclasses provide :attr:`scheme` and list the hooks they implement in
    :attr:`handles`. :meth:`alias` then exposes each hook under the
    scheme-specific attribute name built from the hook's ``dest_pattern``.
    """

    @property
    @abstractmethod
    def scheme(self) -> str:
        """The URL scheme used to build the alias names."""

    @property
    def handles(self) -> Sequence[Method]:
        """Hooks implemented by this handler; none by default."""
        return []  # pragma: no-cover

    def alias(self):
        """
        Add aliases that are required by urllib for handled methods.
        """
        for hook in self.handles:
            implementation = getattr(self, hook.src_attr)
            alias_name = hook.dest_pattern.format(self.scheme)
            setattr(self, alias_name, implementation)

    def request(self, request: Request) -> Request:
        """Request hook; the default does nothing and returns None."""

    def urlopen(self, request: Request) -> Response:
        """Open hook; the default does nothing and returns None."""

    def response(self, request: Request, response: Response) -> Response:
        """Response hook; the default does nothing and returns None."""
def install_schemes():
    """
    Install a global urllib opener with one handler per plugin registered
    under the ``pytest_wdl.url_schemes`` entry-point group.

    A plugin whose handler construction raises :class:`PluginError` is
    skipped (and the reason logged at debug level) so that the remaining
    schemes are still installed.
    """
    def create_handler(_entry_point):
        # Instantiate the plugin as a UrlHandler and register its urllib
        # method aliases before it is handed to build_opener.
        handler_factory = PluginFactory(_entry_point, UrlHandler)
        handler = handler_factory()
        handler.alias()
        return handler

    handlers = []
    for entry_point in iter_entry_points(group="pytest_wdl.url_schemes"):
        try:
            handlers.append(create_handler(entry_point))
        except PluginError as err:
            # Best-effort: an unavailable plugin must not block the others,
            # but record why it was skipped instead of failing silently.
            LOG.debug(f"Skipping URL scheme plugin {entry_point}: {err}")
    install_opener(build_opener(*handlers))