# Source code for uvhttp.utils

import asyncio
import functools
import multiprocessing
import os
import socket
import ssl
from json import loads
from sanic import Sanic
from sanic.response import json

# Number of worker processes used by run_workers(): the value of the
# GOMAXPROCS environment variable when it is set, otherwise twice the CPU
# count reported by multiprocessing.
NUM_WORKERS = int(os.environ.get("GOMAXPROCS", multiprocessing.cpu_count() * 2))

def start_loop(func):
    """
    Decorator that runs the coroutine function ``func`` on a fresh event loop.

    The returned callable creates a new event loop, calls ``func`` with that
    loop as the first argument followed by whatever positional and keyword
    arguments the callable itself received, runs the resulting coroutine to
    completion and returns its result. The loop is closed afterwards, even if
    the coroutine raises.
    """
    @functools.wraps(func)
    def new_func(*args, **kwargs):
        loop = asyncio.new_event_loop()
        try:
            # Forward the caller's arguments instead of silently dropping
            # them, and hand the coroutine's result back to the caller.
            return loop.run_until_complete(func(loop, *args, **kwargs))
        finally:
            # Each call creates its own loop, so release it here.
            loop.close()

    return new_func

def run_workers(func):
    """
    Run ``func`` in ``NUM_WORKERS`` separate processes and wait for all of
    them to exit.

    When ``NUM_WORKERS`` is 1 or less, no process is spawned and ``func`` is
    called directly in the current process instead.
    """
    if NUM_WORKERS <= 1:
        func()
        return

    workers = []
    for _ in range(NUM_WORKERS):
        worker = multiprocessing.Process(target=func)
        worker.start()
        workers.append(worker)

    # Block until every spawned worker has finished.
    for worker in workers:
        worker.join()

def is_ip(host):
    """
    Return True if ``host`` is an IPv4 or IPv6 address.

    ``host`` may be ``str`` or ``bytes``; bytes are decoded before parsing.
    Input that cannot be decoded, or that the address parser refuses outright
    (for example a string containing an embedded NUL character), is reported
    as not being an IP address instead of raising.
    """
    if isinstance(host, bytes):
        try:
            host = host.decode()
        except UnicodeDecodeError:
            # Bytes that are not valid text cannot be an address literal.
            return False

    for family in (socket.AF_INET, socket.AF_INET6):
        try:
            socket.inet_pton(family, host)
        except (socket.error, ValueError):
            # socket.error: not a valid address for this family.
            # ValueError: the string could not be passed to the parser at
            # all (e.g. embedded NUL), so it is not an address either.
            continue
        return True

    return False

class HeaderDict:
    """
    Case-insensitive lookup wrapper around a mapping of headers.

    Lookups ignore the case of the header name, while iteration, ``keys()``
    and ``items()`` report each name with the casing it had in ``original``.
    """

    def __init__(self, original):
        # Index by upper-cased name and keep the (original name, value) pair
        # so the original casing can still be reported.
        self.__dict = {k.upper(): (k, v) for k, v in original.items()}

    def __getitem__(self, key):
        """
        Return the value stored for ``key`` (in any casing), or ``b''`` when
        no such header exists.
        """
        try:
            return self.__dict[key.upper()][1]
        except KeyError:
            return b''

    def __iter__(self):
        """
        Iterate over the header names in their original casing.
        """
        for name, _ in self.__dict.values():
            yield name

    def keys(self):
        """
        Return a list of the header names in their original casing.
        """
        return list(self)

    def items(self):
        """
        Yield ``(name, value)`` pairs, with names in their original casing.
        """
        # The stored values already are the (original name, value) pairs.
        yield from self.__dict.values()

class HttpServer:
    """
    An HTTP server that uses Sanic in the backend to help write unit tests.

    The same Sanic application is served twice: over plain HTTP on
    ``host``/``port`` and over HTTPS on ``https_host``/``https_port``.
    """

    def __init__(self, host=None, port=None, https_host=None, https_port=None):
        """
        Create the Sanic application and register its routes.

        Any argument left as ``None`` (or otherwise falsy) falls back to its
        default: ``127.0.0.1`` for both hosts, ``8089`` for the HTTP port and
        ``8090`` for the HTTPS port.
        """
        self.app = Sanic(__name__)
        self.app.config.LOGO = None

        self.host = host or '127.0.0.1'
        self.port = port or 8089

        self.https_host = https_host or '127.0.0.1'
        self.https_port = https_port or 8090

        self.add_routes()

    @property
    def url(self):
        """
        Return the URL of the started server.

        The URL is returned as ``bytes``.
        """
        return 'http://{}:{}/'.format(self.host, self.port).encode()

    @property
    def https_url(self):
        """
        Return the HTTPS URL of the started server.

        The URL is returned as ``bytes``.
        """
        return 'https://{}:{}/'.format(self.https_host, self.https_port).encode()

    async def start(self):
        """
        Start the server.

        Creates the plain HTTP server first, then the HTTPS server using the
        ``example.pem`` certificate located next to this module. If setting
        up HTTPS fails, the already created HTTP server is closed before the
        error is re-raised.
        """
        self.server = await self.app.create_server(host=self.host, port=self.port)

        try:
            pem = os.path.join(os.path.dirname(__file__), 'example.pem')
            ctx = ssl.create_default_context(purpose=ssl.Purpose.CLIENT_AUTH)
            ctx.load_cert_chain(pem)

            self.https_server = await self.app.create_server(
                host=self.https_host, port=self.https_port, ssl=ctx)
        except Exception:
            # Do not leave the HTTP server open when the HTTPS half could
            # not be started; the caller never gets a chance to stop() it.
            self.server.close()
            raise

    def stop(self):
        """
        Stop the server.

        Closes both the HTTP and the HTTPS server created by :meth:`start`.
        """
        self.server.close()
        self.https_server.close()

    def add_routes(self):
        """
        A function for adding all of the routes you need. Implement this in
        your inherited class to add your endpoints.

        The default implementation registers :meth:`echo` for ``GET`` and
        ``POST``.
        """
        self.app.add_route(self.echo, "echo", ['GET', 'POST'])

    async def echo(self, request):
        """
        An echo endpoint that returns all of the data about the request.

        The request body is additionally parsed as JSON; when it cannot be
        parsed the ``json`` field of the response is ``None``.
        """
        try:
            parsed_json = loads(request.body)
        except (ValueError, TypeError):
            parsed_json = None

        return json({
            "form": request.form,
            "body": request.body,
            "args": request.args,
            "url": request.url,
            "query": request.query_string,
            "json": parsed_json,
            "headers": request.headers
        })
def http_server(http_server_cls, *server_args, **server_kwargs):
    """
    Start a :class:`.HttpServer` and an event loop for use in a test case.

    Returns a decorator for a coroutine function. The decorated coroutine is
    called with the started server instance and the event loop supplied by
    ``start_loop``, followed by any further arguments the inner wrapper
    receives. The server is built as
    ``http_server_cls(*server_args, **server_kwargs)`` and is stopped once
    the coroutine finishes, whether it returns or raises.
    """
    def http_server_wrapper(func):
        @functools.wraps(func)
        @start_loop
        async def real_http_server_wrapper(loop, *args, **kwargs):
            server = http_server_cls(*server_args, **server_kwargs)
            await server.start()

            try:
                await func(server, loop, *args, **kwargs)
            finally:
                server.stop()

        return real_http_server_wrapper
    return http_server_wrapper


def http_server_no_loop(http_server_cls, *server_args, **server_kwargs):
    """
    Start a :class:`.HttpServer` for use in a test case.

    Returns a decorator for a coroutine function. The result is itself a
    coroutine function, so the caller provides the event loop. When awaited
    it builds ``http_server_cls(*server_args, **server_kwargs)``, starts it,
    awaits the decorated coroutine with the server as first argument followed
    by the caller's own arguments, and stops the server afterwards, whether
    the coroutine returns or raises.
    """
    def http_server_wrapper(func):
        @functools.wraps(func)
        async def real_http_server_wrapper(*args, **kwargs):
            server = http_server_cls(*server_args, **server_kwargs)
            await server.start()

            try:
                await func(server, *args, **kwargs)
            finally:
                server.stop()

        return real_http_server_wrapper
    return http_server_wrapper