-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtcp_tunnel.py
147 lines (118 loc) · 5.54 KB
/
tcp_tunnel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import asyncio
import argparse
import struct
import json
from functools import partial
# Module-wide logger used by TcpTunnel's connection handlers; it stays None
# until main() replaces it with the logger returned by get_logger().
logger = None
class TcpTunnel(object):
    """Two-sided TCP port forwarder built on asyncio streams.

    The *local* side listens on the configured local ports. For every
    accepted client it connects to a tunnel server and sends one request
    frame: a 4-byte big-endian length (``'!I'``) followed by a UTF-8 JSON
    object ``{"dst_host": ..., "dst_port": ...}``. After that, bytes are
    relayed verbatim in both directions.

    The *server* side accepts such connections, reads the request frame,
    connects to the requested destination and relays bytes both ways.
    """

    # Size in bytes of the length prefix that precedes the JSON request.
    HEADER_SIZE = 4
    # Upper bound on the JSON request size; the length prefix comes from the
    # network, so it must not be trusted to size an allocation.
    MAX_REQUEST_SIZE = 64 * 1024
    # Number of bytes requested from the source per relay iteration.
    CHUNK_SIZE = 4096

    def __init__(self) -> None:
        # local_port -> (server_host, server_port, remote_host, remote_port)
        self.port_mapping = dict()
        # list of (server_host, server_port) pairs to listen on as a server
        self.servers = []

    def add_port_mapping(self, local_port, server_host, server_port, remote_host, remote_port):
        """Forward ``local_port`` through the given tunnel server to the remote address."""
        self.port_mapping[local_port] = (server_host, server_port, remote_host, remote_port)

    def add_server(self, server_port, server_host='0.0.0.0'):
        """Register a tunnel-server listening address."""
        self.servers.append((server_host, server_port))

    def start(self):
        """Start every configured listener and block until they all stop.

        Returns when all listeners have finished or on ``KeyboardInterrupt``.
        """
        async def _start_acceptors():
            acceptors = [
                asyncio.create_task(self._create_server(host, port, self.server_handle))
                for host, port in self.servers
            ]
            acceptors.extend(
                asyncio.create_task(
                    self._create_server('0.0.0.0', local_port, partial(self.local_handle, tunnel_info))
                )
                for local_port, tunnel_info in self.port_mapping.items()
            )
            results = await asyncio.gather(*acceptors, return_exceptions=True)
            # return_exceptions=True keeps one failing listener from tearing
            # down the others, but the failure (e.g. port already in use)
            # must still be reported instead of being dropped.
            for result in results:
                if isinstance(result, Exception):
                    self._log().error(f"[listener] failed: {result!r}")

        TcpTunnel._run(_start_acceptors())

    async def server_handle(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Serve one tunnel connection on the server side.

        Reads the request frame, connects to the requested destination and
        relays traffic until either side closes. The client connection is
        closed if the request is malformed or the destination is unreachable.
        """
        log = self._log()
        peer = writer.get_extra_info('peername')
        tunnel_req = await self._read_request(reader)
        if tunnel_req is None:
            log.warning(f"[server] [bad request] {peer}")
            writer.close()
            return
        try:
            dst_reader, dst_writer = await asyncio.open_connection(
                tunnel_req['dst_host'], tunnel_req['dst_port'])
        except (OSError, TypeError, ValueError, OverflowError) as exc:
            # The host/port values are supplied by the peer, so both network
            # failures and nonsensical values are expected here.
            log.warning(f"[server] [connect failed] {peer} -> {tunnel_req!r}: {exc!r}")
            writer.close()
            return
        conn_msg = (f"{peer} <==> {dst_writer.get_extra_info('sockname')}"
                    f" <==> {dst_writer.get_extra_info('peername')}")
        log.info(f"[server] [connect] {conn_msg}")
        await self._relay(reader, writer, dst_reader, dst_writer)
        log.info(f"[server] [disconnect] {conn_msg}")

    async def local_handle(self, tunnel_info, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Serve one client accepted on a local forwarded port.

        ``tunnel_info`` is the ``(server_host, server_port, remote_host,
        remote_port)`` tuple stored by :meth:`add_port_mapping`. Connects to
        the tunnel server, sends the request frame and relays traffic.
        """
        log = self._log()
        server_host, server_port, remote_host, remote_port = tunnel_info
        peer = writer.get_extra_info('peername')
        try:
            remote_reader, remote_writer = await asyncio.open_connection(server_host, server_port)
        except (OSError, ValueError) as exc:
            log.warning(f"[client] [connect failed] {peer} -> ('{server_host}', {server_port}): {exc!r}")
            writer.close()
            return
        tunnel_req = json.dumps({'dst_host': remote_host, 'dst_port': remote_port}).encode()
        remote_writer.write(struct.pack('!I', len(tunnel_req)) + tunnel_req)
        conn_msg = f"{peer} <==> ('{server_host}', {server_port}) <==> ('{remote_host}', {remote_port})"
        log.info(f"[client] [connect] {conn_msg}")
        await self._relay(reader, writer, remote_reader, remote_writer)
        log.info(f"[client] [disconnect] {conn_msg}")

    async def trans(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
        """Copy bytes from ``reader`` to ``writer`` until EOF, then close ``writer``.

        Stops early if ``writer`` is already closing or the connection fails.
        """
        try:
            while not writer.is_closing():
                data = await reader.read(self.CHUNK_SIZE)
                if not data:  # EOF on the source side
                    break
                writer.write(data)
                # Apply back-pressure so a slow receiver cannot make the
                # write buffer grow without bound.
                await writer.drain()
        except OSError as exc:
            # A reset or broken pipe is a normal way for a relayed
            # connection to end; just stop this direction.
            self._log().debug(f"[relay] connection error: {exc!r}")
        finally:
            if not writer.is_closing():
                writer.close()

    async def _relay(self, reader_a, writer_a, reader_b, writer_b):
        """Pump a->b and b->a concurrently and close both writers afterwards."""
        try:
            # gather() accepts coroutines directly; asyncio.wait() rejects
            # bare coroutines on Python 3.11+.
            await asyncio.gather(self.trans(reader_b, writer_a), self.trans(reader_a, writer_b))
        finally:
            for stream_writer in (writer_a, writer_b):
                if not stream_writer.is_closing():
                    stream_writer.close()

    async def _read_request(self, reader):
        """Read and decode one request frame.

        Returns the request dict, or ``None`` when the frame is truncated,
        oversized, not valid JSON, or lacks ``dst_host``/``dst_port``.
        """
        try:
            header = await reader.readexactly(self.HEADER_SIZE)
            (size,) = struct.unpack('!I', header)
            if size > self.MAX_REQUEST_SIZE:
                return None
            body = await reader.readexactly(size)
            # JSONDecodeError and UnicodeDecodeError are both ValueError.
            request = json.loads(body.decode())
        except (asyncio.IncompleteReadError, OSError, ValueError):
            return None
        if not isinstance(request, dict):
            return None
        if 'dst_host' not in request or 'dst_port' not in request:
            return None
        return request

    @staticmethod
    def _log():
        """Return the logger to use for connection events.

        Prefers the module-level ``logger`` when it has been set; otherwise
        falls back to the ``tcp_tunnel_logger`` named logger so the class
        also works when imported and used without running ``main()``.
        """
        import logging
        configured = globals().get('logger')
        if configured is not None:
            return configured
        return logging.getLogger('tcp_tunnel_logger')

    @staticmethod
    async def _create_server(host, port, handle):
        """Listen on ``host:port`` and dispatch each connection to ``handle`` forever."""
        server = await asyncio.start_server(handle, host, port)
        async with server:
            await server.serve_forever()

    @staticmethod
    def _run(coroutine):
        """Run ``coroutine`` to completion, treating Ctrl-C as a clean shutdown."""
        try:
            asyncio.run(coroutine)
        except KeyboardInterrupt:
            pass
def parse_server_addr(s):
    """Parse a ``[host:]port`` listen address into a ``(host, port)`` tuple.

    A bare port binds to ``0.0.0.0``. Only the first two colon-separated
    fields are used. A non-numeric port raises ``ValueError``.
    """
    parts = s.split(':')
    if len(parts) > 1:
        return parts[0], int(parts[1])
    # No colon at all: the whole string is the port.
    return '0.0.0.0', int(parts[0])
def parse_remote(s):
    """Parse a ``host:port`` string into a ``(host, port)`` tuple.

    Raises ``ValueError`` unless the string has exactly one colon and a
    numeric port.
    """
    pieces = s.split(':')
    # Unpacking enforces exactly two fields.
    hostname, raw_port = pieces
    return (hostname, int(raw_port))
def parse_args():
    """Parse the command line and normalise the tunnel specifications.

    Options:
        -t/--tunnel  repeatable; ``local_port:remote_host:remote_port`` (the
                     tunnel server is taken from ``--remote``) or
                     ``local_port:server_host:server_port:remote_host:remote_port``.
        -s/--server  ``[host:]port`` to listen on as a tunnel server.
        -r/--remote  ``host:port`` of the default tunnel server.

    Returns the argparse namespace. When ``--tunnel`` was given,
    ``args.tunnel`` is a list of ``(local_port, server_host, server_port,
    remote_host, remote_port)`` tuples. Malformed tunnel specs terminate the
    program through ``parser.error`` with a usage message.
    """
    parser = argparse.ArgumentParser(description='tcp tunnel')
    parser.add_argument('-t', '--tunnel', action='append', required=False)
    parser.add_argument('-s', '--server', type=parse_server_addr, default=None, required=False)
    parser.add_argument('-r', '--remote', type=parse_remote, default=(None, None), required=False)
    args = parser.parse_args()

    if args.tunnel is not None:
        def parse_tunnel_conf(conf, server_host, server_port):
            fields = conf.split(':')
            try:
                if len(fields) == 3:
                    # The short form relies on --remote for the tunnel
                    # server; without it every connection would fail later.
                    if server_host is None or server_port is None:
                        parser.error(f"tunnel {conf!r} needs -r/--remote host:port")
                    return int(fields[0]), server_host, server_port, fields[1], int(fields[2])
                if len(fields) >= 5:
                    return int(fields[0]), fields[1], int(fields[2]), fields[3], int(fields[4])
            except ValueError:
                parser.error(f"tunnel {conf!r} contains a non-numeric port")
            parser.error(
                f"tunnel {conf!r} must be local_port:remote_host:remote_port or "
                "local_port:server_host:server_port:remote_host:remote_port")

        default_host, default_port = args.remote
        args.tunnel = [parse_tunnel_conf(conf, default_host, default_port) for conf in args.tunnel]
    return args
def get_logger():
    """Return the ``tcp_tunnel_logger`` logger, configured to log to a file.

    The logger is set to INFO and writes to ``tcp_tunnel.log`` in the current
    directory, rotating at 100 MiB with two backups. The file handler is
    attached only once, so calling this function repeatedly does not
    duplicate log lines.
    """
    import logging
    import logging.handlers

    log = logging.getLogger("tcp_tunnel_logger")
    log.setLevel(logging.INFO)
    # logging.getLogger returns the same object for the same name, so a
    # handler added by an earlier call is still attached.
    has_file_handler = any(
        isinstance(handler, logging.handlers.RotatingFileHandler)
        for handler in log.handlers
    )
    if not has_file_handler:
        channel = logging.handlers.RotatingFileHandler(
            filename="tcp_tunnel.log", maxBytes=100*1024*1024, backupCount=2)
        formatter = logging.Formatter("%(asctime)s|%(filename)s:%(lineno)s|%(levelname)s|%(message)s")
        channel.setFormatter(formatter)
        log.addHandler(channel)
    return log
def main():
    """Command-line entry point: parse arguments, set up logging, run the tunnel."""
    global logger
    # Parse first so that --help or an invalid command line exits before the
    # log file is created.
    args = parse_args()
    logger = get_logger()
    tunnel = TcpTunnel()
    if args.server is not None:
        host, port = args.server
        tunnel.add_server(server_port=port, server_host=host)
    for tun in args.tunnel or []:
        tunnel.add_port_mapping(*tun)
    if args.server is None and not args.tunnel:
        # Nothing to listen on: say so rather than exiting silently.
        logger.warning("nothing to do: pass -s/--server and/or -t/--tunnel")
        return
    tunnel.start()
# Run the tunnel only when executed as a script, not when imported.
if __name__ == "__main__":
    main()