# -*- coding: utf-8 -*-
"""
proxy.py
~~~~~~~~
⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
Network monitoring, controls & Application development, testing, debugging.
:copyright: (c) 2013-present by Abhinav Singh and contributors.
:license: BSD, see LICENSE for more details.
"""
import os
import asyncio
import logging
import argparse
import selectors
import multiprocessing
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING, Any, Set, Dict, List, Tuple, Generic, TypeVar, Optional,
cast,
)
from ...common.types import Readables, Writables, SelectableEvents
from ...common.logger import Logger
from ...common.constants import (
DEFAULT_WAIT_FOR_TASKS_TIMEOUT, DEFAULT_SELECTOR_SELECT_TIMEOUT,
DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT,
)
if TYPE_CHECKING: # pragma: no cover
from .work import Work
from ..event import EventQueue
# Generic payload type received over the work queue (see Threadless.work_queue).
T = TypeVar('T')
# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
class Threadless(ABC, Generic[T]):
    """Work executor base class.

    Threadless provides an event loop, which is shared across
    multiple :class:`~proxy.core.acceptor.work.Work` instances to handle
    work.

    Threadless takes input a `work_klass` and an `event_queue`. `work_klass`
    must conform to the :class:`~proxy.core.acceptor.work.Work`
    protocol. Work is received over the `event_queue`.

    When a work is accepted, threadless creates a new instance of `work_klass`.
    Threadless will then invoke necessary lifecycle of the
    :class:`~proxy.core.acceptor.work.Work` protocol,
    allowing `work_klass` implementation to handle the assigned work.

    Example, :class:`~proxy.core.base.tcp_server.BaseTcpServerHandler`
    implements :class:`~proxy.core.acceptor.work.Work` protocol. It
    expects a client connection as work payload and hooks into the
    threadless event loop to handle the client connection.
    """

    def __init__(
            self,
            iid: str,
            work_queue: T,
            flags: argparse.Namespace,
            event_queue: Optional['EventQueue'] = None,
    ) -> None:
        super().__init__()
        self.iid = iid
        self.work_queue = work_queue
        self.flags = flags
        self.event_queue = event_queue
        # Set by the parent process to signal shutdown.
        self.running = multiprocessing.Event()
        # Active Work instances keyed by work id (client fd in threadless mode).
        self.works: Dict[int, 'Work[Any]'] = {}
        self.selector: Optional[selectors.DefaultSelector] = None
        # If we remove single quotes for typing hint below,
        # runtime exceptions will occur for < Python 3.9.
        #
        # Ref https://github.com/abhinavsingh/proxy.py/runs/4279055360?check_suite_focus=true
        self.unfinished: Set['asyncio.Task[bool]'] = set()
        self.registered_events_by_work_ids: Dict[
            # work_id
            int,
            # fileno, mask
            SelectableEvents,
        ] = {}
        self.wait_timeout: float = DEFAULT_WAIT_FOR_TASKS_TIMEOUT
        self.cleanup_inactive_timeout: float = DEFAULT_INACTIVE_CONN_CLEANUP_TIMEOUT
        self._total: int = 0
        # When put at the top, causes circular import error
        # since integrated ssh tunnel was introduced.
        from ..connection import (  # pylint: disable=C0415
            UpstreamConnectionPool,
        )
        self._upstream_conn_pool: Optional['UpstreamConnectionPool'] = None
        # Descriptors currently registered on behalf of the connection pool
        # (selector data=0), not on behalf of any individual work.
        self._upstream_conn_filenos: Set[int] = set()
        if self.flags.enable_conn_pool:
            self._upstream_conn_pool = UpstreamConnectionPool()

    @property
    @abstractmethod
    def loop(self) -> Optional[asyncio.AbstractEventLoop]:
        """Event loop used by this executor.  Provided by subclasses."""
        raise NotImplementedError()

    @abstractmethod
    def receive_from_work_queue(self) -> bool:
        """Work queue is ready to receive new work.

        Receive it and call ``work_on_tcp_conn``.

        Return True to tear down the loop."""
        raise NotImplementedError()

    @abstractmethod
    def work_queue_fileno(self) -> Optional[int]:
        """If work queue must be selected before calling
        ``receive_from_work_queue`` then implementation must
        return work queue fd."""
        raise NotImplementedError()

    @abstractmethod
    def work(self, *args: Any) -> None:
        """Handle a unit of work received over the work queue."""
        raise NotImplementedError()

    def create(self, uid: str, *args: Any) -> 'Work[T]':
        """Instantiate ``work_klass`` for the accepted work payload."""
        return cast(
            'Work[T]', self.flags.work_klass(
                self.flags.work_klass.create(*args),
                flags=self.flags,
                event_queue=self.event_queue,
                uid=uid,
                upstream_conn_pool=self._upstream_conn_pool,
            ),
        )

    def close_work_queue(self) -> None:
        """Only called if ``work_queue_fileno`` returns an integer.
        If an fd is select-able for work queue, make sure
        to close the work queue fd now."""
        pass    # pragma: no cover

    async def _update_work_events(self, work_id: int) -> None:
        """Register/modify selector events of interest for ``work_id``."""
        assert self.selector is not None
        worker_events = await self.works[work_id].get_events()
        # NOTE: Current assumption is that multiple works will not
        # be interested in the same fd.  Descriptors of interests
        # returned by work must be unique.
        #
        # TODO: Ideally we must diff and unregister socks not
        # returned of interest within current _select_events call
        # but exists in the registered_socks_by_work_ids registry.
        for fileno in worker_events:
            if work_id not in self.registered_events_by_work_ids:
                self.registered_events_by_work_ids[work_id] = {}
            mask = worker_events[fileno]
            if fileno in self.registered_events_by_work_ids[work_id]:
                oldmask = self.registered_events_by_work_ids[work_id][fileno]
                if mask != oldmask:
                    self.selector.modify(
                        fileno, events=mask,
                        data=work_id,
                    )
                    self.registered_events_by_work_ids[work_id][fileno] = mask
                    logger.debug(
                        'fd#{0} modified for mask#{1} by work#{2}'.format(
                            fileno, mask, work_id,
                        ),
                    )
                # else:
                #     logger.info(
                #         'fd#{0} by work#{1} not modified'.format(fileno, work_id))
            elif fileno in self._upstream_conn_filenos:
                # Descriptor offered by work, but is already registered by connection pool
                # Most likely because work has acquired a reusable connection.
                self.selector.modify(fileno, events=mask, data=work_id)
                self.registered_events_by_work_ids[work_id][fileno] = mask
                self._upstream_conn_filenos.remove(fileno)
                logger.debug(
                    'fd#{0} borrowed with mask#{1} by work#{2}'.format(
                        fileno, mask, work_id,
                    ),
                )
            # Can throw ValueError: Invalid file descriptor: -1
            #
            # A guard within Work classes may not help here due to
            # asynchronous nature.  Hence, threadless will handle
            # ValueError exceptions raised by selector.register
            # for invalid fd.
            #
            # TODO: Also remove offending work from pool to avoid spin loop.
            elif fileno != -1:
                self.selector.register(fileno, events=mask, data=work_id)
                self.registered_events_by_work_ids[work_id][fileno] = mask
                logger.debug(
                    'fd#{0} registered for mask#{1} by work#{2}'.format(
                        fileno, mask, work_id,
                    ),
                )

    async def _update_conn_pool_events(self) -> None:
        """Sync selector registrations with connection-pool descriptors.

        Pool descriptors are registered with ``data=0`` (a reserved work id)
        so that ready events are dispatched to the pool, not a work."""
        if not self._upstream_conn_pool:
            return
        assert self.selector is not None
        new_conn_pool_events = await self._upstream_conn_pool.get_events()
        old_conn_pool_filenos = self._upstream_conn_filenos.copy()
        self._upstream_conn_filenos.clear()
        new_conn_pool_filenos = set(new_conn_pool_events.keys())
        new_conn_pool_filenos.difference_update(old_conn_pool_filenos)
        for fileno in new_conn_pool_filenos:
            self.selector.register(
                fileno,
                events=new_conn_pool_events[fileno],
                data=0,
            )
            self._upstream_conn_filenos.add(fileno)
        # Descriptors no longer owned by the pool (e.g. borrowed by a work)
        # must be unregistered.
        old_conn_pool_filenos.difference_update(self._upstream_conn_filenos)
        for fileno in old_conn_pool_filenos:
            self.selector.unregister(fileno)

    async def _update_selector(self) -> None:
        """Refresh selector interests for all works without a pending task."""
        assert self.selector is not None
        unfinished_work_ids = set()
        for task in self.unfinished:
            unfinished_work_ids.add(task._work_id)  # type: ignore
        for work_id in self.works:
            # We don't want to invoke work objects which haven't
            # yet finished their previous task
            if work_id in unfinished_work_ids:
                continue
            await self._update_work_events(work_id)
        await self._update_conn_pool_events()

    async def _selected_events(self) -> Tuple[
            Dict[int, Tuple[Readables, Writables]],
            bool,
    ]:
        """For each work, collects events that they are interested in.
        Calls select for events of interest.

        Returns a 2-tuple containing a dictionary and boolean.

        Dictionary keys are work IDs and values are 2-tuple
        containing ready readables & writables.

        Returned boolean value indicates whether there is
        a newly accepted work waiting to be received and
        queued for processing.  This is only applicable when
        :class:`~proxy.core.work.threadless.Threadless.work_queue_fileno`
        returns a valid fd.
        """
        assert self.selector is not None
        await self._update_selector()
        # Keys are work_id and values are 2-tuple indicating
        # readables & writables that work_id is interested in
        # and are ready for IO.
        work_by_ids: Dict[int, Tuple[Readables, Writables]] = {}
        new_work_available = False
        wqfileno = self.work_queue_fileno()
        if wqfileno is None:
            # When ``work_queue_fileno`` returns None,
            # always return True for the boolean value.
            new_work_available = True
        events = self.selector.select(
            timeout=DEFAULT_SELECTOR_SELECT_TIMEOUT,
        )
        for key, mask in events:
            if not new_work_available and wqfileno is not None and key.fileobj == wqfileno:
                assert mask & selectors.EVENT_READ
                new_work_available = True
                continue
            if key.data not in work_by_ids:
                work_by_ids[key.data] = ([], [])
            if mask & selectors.EVENT_READ:
                work_by_ids[key.data][0].append(key.fd)
            if mask & selectors.EVENT_WRITE:
                work_by_ids[key.data][1].append(key.fd)
        return (work_by_ids, new_work_available)

    async def _wait_for_tasks(self) -> Set['asyncio.Task[bool]']:
        """Wait (bounded) for pending tasks; return those that completed."""
        finished, self.unfinished = await asyncio.wait(
            self.unfinished,
            timeout=self.wait_timeout,
            return_when=asyncio.FIRST_COMPLETED,
        )
        return finished  # noqa: WPS331

    def _cleanup_inactive(self) -> None:
        """Tear down works that report themselves inactive."""
        inactive_works: List[int] = []
        for work_id in self.works:
            if self.works[work_id].is_inactive():
                inactive_works.append(work_id)
        for work_id in inactive_works:
            self._cleanup(work_id)

    # TODO: HttpProtocolHandler.shutdown can call flush which may block
    def _cleanup(self, work_id: int) -> None:
        """Unregister events, shutdown and drop the work for ``work_id``."""
        if work_id in self.registered_events_by_work_ids:
            assert self.selector
            for fileno in self.registered_events_by_work_ids[work_id]:
                logger.debug(
                    'fd#{0} unregistered by work#{1}'.format(
                        fileno, work_id,
                    ),
                )
                self.selector.unregister(fileno)
            self.registered_events_by_work_ids[work_id].clear()
            del self.registered_events_by_work_ids[work_id]
        self.works[work_id].shutdown()
        del self.works[work_id]
        # When a work queue fd is in use, work_id is the accepted
        # client connection's fd and must be closed here.
        if self.work_queue_fileno() is not None:
            os.close(work_id)

    def _create_tasks(
            self,
            work_by_ids: Dict[int, Tuple[Readables, Writables]],
    ) -> Set['asyncio.Task[bool]']:
        """Schedule ``handle_events`` for each ready work (or the pool)."""
        assert self.loop
        tasks: Set['asyncio.Task[bool]'] = set()
        for work_id in work_by_ids:
            # work_id 0 is reserved for the upstream connection pool.
            if work_id == 0:
                assert self._upstream_conn_pool
                task = self.loop.create_task(
                    self._upstream_conn_pool.handle_events(
                        *work_by_ids[work_id],
                    ),
                )
            else:
                task = self.loop.create_task(
                    self.works[work_id].handle_events(*work_by_ids[work_id]),
                )
            task._work_id = work_id  # type: ignore[attr-defined]
            # task.set_name(work_id)
            tasks.add(task)
        return tasks

    async def _run_once(self) -> bool:
        """One iteration of the event loop.  Returns True to tear down."""
        assert self.loop is not None
        work_by_ids, new_work_available = await self._selected_events()
        # Accept new work if available
        #
        # TODO: We must use a work klass to handle
        # client_queue fd itself a.k.a. accept_client
        # will become handle_readables.
        if new_work_available:
            teardown = self.receive_from_work_queue()
            if teardown:
                return teardown
        if len(work_by_ids) == 0:
            return False
        # Invoke Threadless.handle_events
        self.unfinished.update(self._create_tasks(work_by_ids))
        # logger.debug('Executing {0} works'.format(len(self.unfinished)))
        # Cleanup finished tasks
        for task in await self._wait_for_tasks():
            # Checking for result can raise exception e.g.
            # CancelledError, InvalidStateError or an exception
            # from underlying task e.g. TimeoutError.
            teardown = False
            work_id = task._work_id  # type: ignore
            try:
                teardown = task.result()
            finally:
                if teardown:
                    self._cleanup(work_id)
            # self.cleanup(int(task.get_name()))
        # logger.debug(
        #     'Done executing works, {0} pending, {1} registered'.format(
        #         len(self.unfinished), len(self.registered_events_by_work_ids),
        #     ),
        # )
        return False

    async def _run_forever(self) -> None:
        """Run ``_run_once`` until teardown or the running event is set."""
        tick = 0
        try:
            while True:
                if await self._run_once():
                    break
                # Check for inactive and shutdown signal
                elapsed = tick * \
                    (DEFAULT_SELECTOR_SELECT_TIMEOUT + self.wait_timeout)
                if elapsed >= self.cleanup_inactive_timeout:
                    self._cleanup_inactive()
                    if self.running.is_set():
                        break
                    tick = 0
                tick += 1
        except KeyboardInterrupt:
            pass
        finally:
            if self.loop:
                self.loop.stop()

    def run(self) -> None:
        """Entry point: set up logging/selector, run the loop, clean up."""
        Logger.setup(
            self.flags.log_file, self.flags.log_level,
            self.flags.log_format,
        )
        wqfileno = self.work_queue_fileno()
        try:
            self.selector = selectors.DefaultSelector()
            if wqfileno is not None:
                # Register work queue fd with itself as data so that
                # _selected_events can recognize it among ready keys.
                self.selector.register(
                    wqfileno,
                    selectors.EVENT_READ,
                    data=wqfileno,
                )
            assert self.loop
            logger.debug('Working on {0} works'.format(len(self.works)))
            self.loop.create_task(self._run_forever())
            self.loop.run_forever()
        except KeyboardInterrupt:
            pass
        finally:
            assert self.selector is not None
            if wqfileno is not None:
                self.selector.unregister(wqfileno)
                self.close_work_queue()
            self.selector.close()
            assert self.loop is not None
            self.loop.run_until_complete(self.loop.shutdown_asyncgens())
            self.loop.close()