Source code for proxy.core.ssh.listener

# -*- coding: utf-8 -*-
"""
    proxy.py
    ~~~~~~~~
    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
    Network monitoring, controls & Application development, testing, debugging.

    :copyright: (c) 2013-present by Abhinav Singh and contributors.
    :license: BSD, see LICENSE for more details.
"""
import logging
import argparse
from typing import TYPE_CHECKING, Any, Set, Callable, Optional


try:
    from paramiko import SSHClient, AutoAddPolicy
    from paramiko.transport import Transport
    if TYPE_CHECKING:   # pragma: no cover
        from paramiko.channel import Channel

        from ...common.types import HostPort
except ImportError:     # pragma: no cover
    pass

from ...common.flag import flags


logger = logging.getLogger(__name__)


flags.add_argument(
    '--tunnel-hostname',
    type=str,
    default=None,
    help='Default: None. Remote hostname or IP address to which SSH tunnel will be established.',
)

flags.add_argument(
    '--tunnel-port',
    type=int,
    default=22,
    help='Default: 22. SSH port of the remote host.',
)

flags.add_argument(
    '--tunnel-username',
    type=str,
    default=None,
    help='Default: None. Username to use for establishing SSH tunnel.',
)

flags.add_argument(
    '--tunnel-ssh-key',
    type=str,
    default=None,
    help='Default: None. Private key path in pem format',
)

flags.add_argument(
    '--tunnel-ssh-key-passphrase',
    type=str,
    default=None,
    help='Default: None. Private key passphrase',
)

flags.add_argument(
    '--tunnel-remote-port',
    type=int,
    default=8899,
    help='Default: 8899. Remote port which will be forwarded locally for proxy.',
)


[docs]class SshTunnelListener: """Connects over SSH and forwards a remote port to local host. Incoming connections are delegated to provided callback.""" def __init__( self, flags: argparse.Namespace, on_connection_callback: Callable[['Channel', 'HostPort', 'HostPort'], None], ) -> None: self.flags = flags self.on_connection_callback = on_connection_callback self.ssh: Optional[SSHClient] = None self.transport: Optional[Transport] = None self.forwarded: Set['HostPort'] = set()
[docs] def start_port_forward(self, remote_addr: 'HostPort') -> None: assert self.transport is not None self.transport.request_port_forward( *remote_addr, handler=self.on_connection_callback, ) self.forwarded.add(remote_addr) logger.info('%s:%d forwarding successful...' % remote_addr)
[docs] def stop_port_forward(self, remote_addr: 'HostPort') -> None: assert self.transport is not None self.transport.cancel_port_forward(*remote_addr) self.forwarded.remove(remote_addr)
def __enter__(self) -> 'SshTunnelListener': self.setup() return self def __exit__(self, *args: Any) -> None: self.shutdown()
[docs] def setup(self) -> None: self.ssh = SSHClient() self.ssh.load_system_host_keys() self.ssh.set_missing_host_key_policy(AutoAddPolicy()) self.ssh.connect( hostname=self.flags.tunnel_hostname, port=self.flags.tunnel_port, username=self.flags.tunnel_username, key_filename=self.flags.tunnel_ssh_key, passphrase=self.flags.tunnel_ssh_key_passphrase, ) logger.info( 'SSH connection established to %s:%d...' % ( self.flags.tunnel_hostname, self.flags.tunnel_port, ), ) self.transport = self.ssh.get_transport()
[docs] def shutdown(self) -> None: for remote_addr in list(self.forwarded): self.stop_port_forward(remote_addr) self.forwarded.clear() if self.transport is not None: self.transport.close() if self.ssh is not None: self.ssh.close()