summaryrefslogtreecommitdiffstats
path: root/misc/tools/tftpproxy.py
diff options
context:
space:
mode:
authorChris Johns <chrisj@rtems.org>2019-08-19 14:29:03 +1000
committerChris Johns <chrisj@rtems.org>2019-08-19 16:44:11 +1000
commitdeb54b61457c831cac8881fb63fd8dbb747bc3ea (patch)
tree5da5224f3f7ab91b2117c652bfea2a0d41e262e5 /misc/tools/tftpproxy.py
parenttester: Add raspberrypi2 BSP. (diff)
downloadrtems-tools-deb54b61457c831cac8881fb63fd8dbb747bc3ea.tar.bz2
misc/tftpproxy: Add a proxy TFTP server.
- Uses a config INI file to map clients to servers - Handle a number of requests to a single server's TFTP port (69) and multiplex to a non-su ports or different servers. - Supports running rtems-test to more than one hardware device using TFTP at once.
Diffstat (limited to 'misc/tools/tftpproxy.py')
-rw-r--r--misc/tools/tftpproxy.py423
1 files changed, 423 insertions, 0 deletions
diff --git a/misc/tools/tftpproxy.py b/misc/tools/tftpproxy.py
new file mode 100644
index 0000000..a815584
--- /dev/null
+++ b/misc/tools/tftpproxy.py
@@ -0,0 +1,423 @@
+#
+# Copyright 2019 Chris Johns (chris@contemporary.software)
+# All rights reserved.
+#
+# Permission to use, copy, modify, and/or distribute this software for any
+# purpose with or without fee is hereby granted, provided that the above
+# copyright notice and this permission notice appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+# OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+#
+# The TFTP proxy redirects a TFTP session to another host. If you have a
+# farm of boards you can configure them to point to this proxy and it will
+# redirect the requests to another machine that is testing it.
+#
+
+from __future__ import print_function
+
+import argparse
+import os
+import socket
+import sys
+import time
+import threading
+
+try:
+ import socketserver
+except:
+ import SocketServer as socketserver
+
+from rtemstoolkit import configuration
+from rtemstoolkit import error
+from rtemstoolkit import log
+from rtemstoolkit import version
+
+import getmac
+
+def host_port_split(ip_port):
+ ips = ip_port.split(':')
+ port = 0
+ if len(ips) >= 1:
+ ip = ips[0]
+ if len(ips) == 2:
+ port = int(ips[1])
+ else:
+ raise error.general('invalid host:port: %s' % (ip_port))
+ return ip, port
+
+class tftp_session(object):
+
+ opcodes = ['nul', 'RRQ', 'WRQ', 'DATA', 'ACK', 'ERROR', 'OACK']
+
+ def __init__(self):
+ self.packets = []
+ self.block = 0
+ self.block_size = 512
+ self.timeout = 0
+ self.finished = True
+
+ def __str__(self):
+ return os.linesep.join([self.decode(p[0], p[1], p[2]) for p in self.packets])
+
+ def data(self, host, port, data):
+ finished = False
+ self.packets += [(host, port, data)]
+ opcode = (data[0] << 8) | data[1]
+ if opcode == 1 or opcode == 2:
+ self.block = 0
+ self.finished = False
+ value = self.get_option('timeout', data)
+ if value is not None:
+ self.timeout = int(value)
+ value = self.get_option('blksize', data)
+ if value is not None:
+ self.block_size = int(value)
+ else:
+ self.block_size = 512
+ elif opcode == 3:
+ self.block = (data[2] << 8) | data[3]
+ if len(data) - 4 < self.block_size:
+ self.finished = True
+ elif opcode == 4:
+ self.block = (data[2] << 8) | data[3]
+ if self.finished:
+ finished = True
+ return finished
+
+ def decode(self, host, port, data):
+ s = ''
+ dlen = len(data)
+ if dlen > 2:
+ opcode = (data[0] << 8) | data[1]
+ if opcode < len(self.opcodes):
+ if opcode == 1 or opcode == 2:
+ s += ' ' + self.opcodes[opcode] + ', '
+ i = 2
+ while data[i] != 0:
+ s += chr(data[i])
+ i += 1
+ while i < dlen - 1:
+ s += ', '
+ i += 1
+ while data[i] != 0:
+ s += chr(data[i])
+ i += 1
+ elif opcode == 3:
+ block = (data[2] << 8) | data[3]
+ s += ' ' + self.opcodes[opcode] + ', '
+ s += '#' + str(block) + ', '
+ if dlen > 4:
+ s += '%02x%02x..%02x%02x' % (data[4], data[5], data[-2], data[-1])
+ else:
+ s += '%02x%02x%02x%02x' % (data[4], data[5], data[6], data[6])
+ s += ',' + str(dlen - 4)
+ elif opcode == 4:
+ block = (data[2] << 8) | data[3]
+ s += ' ' + self.opcodes[opcode] + ' ' + str(block)
+ elif opcode == 5:
+ s += 'E ' + self.opcodes[opcode] + ', '
+ s += str((data[2] << 8) | (data[3]))
+ i = 2
+ while data[i] != 0:
+ s += chr(data[i])
+ i += 1
+ elif opcode == 6:
+ s += ' ' + self.opcodes[opcode]
+ i = 1
+ while i < dlen - 1:
+ s += ', '
+ i += 1
+ while data[i] != 0:
+ s += chr(data[i])
+ i += 1
+ else:
+ s += 'E INV(%d)' % (opcode)
+ else:
+ s += 'E INVALID LENGTH'
+ return s[:2] + '[%s:%d] ' % (host, port) + s[2:]
+
+ def get_option(self, option, data):
+ dlen = len(data)
+ opcode = (data[0] << 8) | data[1]
+ next_option = False
+ if opcode == 1 or opcode == 2:
+ i = 1
+ while i < dlen - 1:
+ o = ''
+ i += 1
+ while data[i] != 0:
+ o += chr(data[i])
+ i += 1
+ if o == option:
+ next_option = True
+ elif next_option:
+ return o
+ return None
+
+ def get_timeout(self, default_timeout, timeout_guard):
+ if self.timeout == 0:
+ return self.timeout + timeout_guard
+ return default_timeout
+
+ def get_block_size(self):
+ return self.block_size
+
+class udp_handler(socketserver.BaseRequestHandler):
+
+ def handle(self):
+ client_ip = self.client_address[0]
+ client_port = self.client_address[1]
+ client = '%s:%i' % (client_ip, client_port)
+ session = tftp_session()
+ finished = session.data(client_ip, client_port, self.request[0])
+ if not finished:
+ timeout = session.get_timeout(self.server.proxy.session_timeout, 1)
+ host = self.server.proxy.get_host(client_ip)
+ if host is not None:
+ session_count = self.server.proxy.get_session_count()
+ log.notice(' ] %6d: session: %s -> %s: start' % (session_count,
+ client,
+ host))
+ host_ip, host_server_port = host_port_split(host)
+ host_port = host_server_port
+ sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ sock.settimeout(timeout)
+ log.trace(' > ' + session.decode(client_ip,
+ client_port,
+ self.request[0]))
+ sock.sendto(self.request[0], (host_ip, host_port))
+ while not finished:
+ try:
+ data, address = sock.recvfrom(16 * 1024)
+ except socket.error as se:
+ log.notice(' ] session: %s -> %s: error: %s' % (client,
+ host,
+ se))
+ return
+ except socket.gaierror as se:
+ log.notice(' ] session: %s -> %s: error: %s' % (client,
+ host,
+ se))
+ return
+ except:
+ return
+ finished = session.data(address[0], address[1], data)
+ if address[0] == host_ip:
+ if host_port == host_server_port:
+ host_port = address[1]
+ if address[1] == host_port:
+ log.trace(' < ' + session.decode(address[0],
+ address[1],
+ data))
+ sock.sendto(data, (client_ip, client_port))
+ elif address[0] == client_ip and address[1] == client_port:
+ log.trace(' > ' + session.decode(address[0],
+ address[1],
+ data))
+ sock.sendto(data, (host_ip, host_port))
+ log.notice(' ] %6d: session: %s -> %s: end' % (session_count,
+ client,
+ host))
+ else:
+ mac = getmac.get_mac_address(ip = client_ip)
+ log.trace(' . request: host not found: %s (%s)' % (client, mac))
+
+class udp_server(socketserver.ThreadingMixIn, socketserver.UDPServer):
+ pass
+
+class proxy_server(object):
+ def __init__(self, config, host, port):
+ self.lock = threading.Lock()
+ self.session_timeout = 10
+ self.host = host
+ self.port = port
+ self.server = None
+ self.clients = { }
+ self.config = configuration.configuration()
+ self._load(config)
+ self.session_counter = 0
+
+ def __del__(self):
+ self.stop()
+
+ def _lock(self):
+ self.lock.acquire()
+
+ def _unlock(self):
+ self.lock.release()
+
+ def _load_client(self, client, depth = 0):
+ if depth > 32:
+ raise error.general('\'clients\'" nesting too deep; circular?')
+ if not self.config.has_section(client):
+ raise error.general('client not found: %s' % (client))
+ for c in self.config.comma_list(client, 'clients', err = False):
+ self._load_client(c, depth + 1)
+ if client in self.clients:
+ raise error.general('repeated client: %s' % (client))
+ host = self.config.get_item(client, 'host', err = False)
+ if host is not None:
+ ips = self.config.comma_list(client, 'ip', err = False)
+ macs = self.config.comma_list(client, 'mac', err = False)
+ if len(ips) != 0 and len(macs) != 0:
+ raise error.general('client has ip and mac: %s' % (client))
+ if len(ips) != 0:
+ keys = ips
+ elif len(macs) != 0:
+ keys = macs
+ else:
+ raise error.general('not client ip or mac: %s' % (client))
+ for key in keys:
+ self.clients[key] = host
+
+ def _load(self, config):
+ self.config.load(config)
+ clients = self.config.comma_list('default', 'clients', err = False)
+ if len(clients) == 0:
+ raise error.general('\'clients\'" entry not found in config [defaults]')
+ for client in clients:
+ self._load_client(client)
+
+ def start(self):
+ log.notice('Proxy: %s:%i' % (self.host, self.port))
+ if self.host == 'all':
+ host = ''
+ else:
+ host = self.host
+ try:
+ self.server = udp_server((host, self.port), udp_handler)
+ except Exception as e:
+ raise error.general('proxy create: %s' % (e))
+ self.server.proxy = self
+ self._lock()
+ try:
+ self.server_thread = threading.Thread(target = self.server.serve_forever)
+ self.server_thread.daemon = True
+ self.server_thread.start()
+ finally:
+ self._unlock()
+
+ def stop(self):
+ self._lock()
+ try:
+ if self.server is not None:
+ self.server.shutdown()
+ self.server.server_close()
+ self.server = None
+ finally:
+ self._unlock()
+
+ def run(self):
+ while True:
+ time.sleep(1)
+
+ def get_host(self, client):
+ host = None
+ self._lock()
+ try:
+ if client in self.clients:
+ host = self.clients[client]
+ else:
+ mac = getmac.get_mac_address(ip = client)
+ if mac in self.clients:
+ host = self.clients[mac]
+ finally:
+ self._unlock()
+ return host
+
+ def get_session_count(self):
+ count = 0
+ self._lock()
+ try:
+ self.session_counter += 1
+ count = self.session_counter
+ finally:
+ self._unlock()
+ return count
+
+
+def load_log(logfile):
+ if logfile is None:
+ log.default = log.log(streams = ['stdout'])
+ else:
+ log.default = log.log(streams = [logfile])
+
+def run(args = sys.argv, command_path = None):
+ ec = 0
+ notice = None
+ proxy = None
+ try:
+ description = 'Proxy TFTP sessions from the host running this proxy'
+ description += 'to hosts and ports defined in the configuration file. '
+ description += 'The tool lets you create a farm of hardware and to run '
+ description += 'more than one TFTP test session on a host or multiple '
+ description += 'hosts at once. This proxy service is not considered secure'
+ description += 'and is for use in a secure environment.'
+
+ argsp = argparse.ArgumentParser(prog = 'rtems-tftp-proxy',
+ description = description)
+ argsp.add_argument('-l', '--log',
+ help = 'log file.',
+ type = str, default = None)
+ argsp.add_argument('-v', '--trace',
+ help = 'enable trace logging for debugging.',
+ action = 'store_true', default = False)
+ argsp.add_argument('-c', '--config',
+ help = 'proxy configuation (default: %(default)s).',
+ type = str, default = None)
+ argsp.add_argument('-B', '--bind',
+ help = 'address to bind the proxy too (default: %(default)s).',
+ type = str, default = 'all')
+ argsp.add_argument('-P', '--port',
+ help = 'port to bind the proxy too(default: %(default)s).',
+ type = int, default = '69')
+
+ argopts = argsp.parse_args(args[1:])
+
+ load_log(argopts.log)
+ log.notice('RTEMS Tools - TFTP Proxy, %s' % (version.string()))
+ log.output(log.info(args))
+ log.tracing = argopts.trace
+
+ if argopts.config is None:
+ raise error.general('no config file, see -h')
+
+ proxy = proxy_server(argopts.config, argopts.bind, argopts.port)
+
+ try:
+ proxy.start()
+ proxy.run()
+ except:
+ proxy.stop()
+ raise
+
+ except error.general as gerr:
+ notice = str(gerr)
+ ec = 1
+ except error.internal as ierr:
+ notice = str(ierr)
+ ec = 1
+ except error.exit as eerr:
+ pass
+ except KeyboardInterrupt:
+ notice = 'abort: user terminated'
+ ec = 1
+ except:
+ raise
+ notice = 'abort: unknown error'
+ ec = 1
+ if proxy is not None:
+ del proxy
+ if notice is not None:
+ log.stderr(notice)
+ sys.exit(ec)
+
+if __name__ == "__main__":
+ run()