diff options
author | Chris Johns <chrisj@rtems.org> | 2020-08-26 13:38:54 +1000 |
---|---|---|
committer | Chris Johns <chrisj@rtems.org> | 2020-09-01 09:18:04 +1000 |
commit | eb3608133b41b9cb7b4dd55cb15d77691d2300c6 (patch) | |
tree | 3361e303477fdbc4a6a63f96054d5bcdf2ada133 /tester/rt/tftpserver.py | |
parent | trace: Use c++14 instead of c++11 if possible (diff) | |
download | rtems-tools-eb3608133b41b9cb7b4dd55cb15d77691d2300c6.tar.bz2 |
tester: Change to a simpler TFTP server
- Add a simpler TFTP to allow parallel test hardware
- Remove the imported tftpy server
Closes #4063
Diffstat (limited to 'tester/rt/tftpserver.py')
-rw-r--r-- | tester/rt/tftpserver.py | 699 |
1 files changed, 699 insertions, 0 deletions
diff --git a/tester/rt/tftpserver.py b/tester/rt/tftpserver.py new file mode 100644 index 0000000..012aae4 --- /dev/null +++ b/tester/rt/tftpserver.py @@ -0,0 +1,699 @@ +# SPDX-License-Identifier: BSD-2-Clause +'''The TFTP Server handles a read only TFTP session.''' + +# Copyright (C) 2020 Chris Johns (chrisj@rtems.org) +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# 1. Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# 2. Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +# POSSIBILITY OF SUCH DAMAGE. + +from __future__ import print_function + +import argparse +import os +import socket +import sys +import time +import threading + +try: + import socketserver +except ImportError: + import SocketServer as socketserver + +from rtemstoolkit import error +from rtemstoolkit import log +from rtemstoolkit import version + + +class tftp_session(object): + '''Handle the TFTP session packets initiated on the TFTP port (69). + ''' + # pylint: disable=useless-object-inheritance + # pylint: disable=too-many-instance-attributes + + opcodes = ['nul', 'RRQ', 'WRQ', 'DATA', 'ACK', 'ERROR', 'OACK'] + + OP_RRQ = 1 + OP_WRQ = 2 + OP_DATA = 3 + OP_ACK = 4 + OP_ERROR = 5 + OP_OACK = 6 + + E_NOT_DEFINED = 0 + E_FILE_NOT_FOUND = 1 + E_ACCESS_VIOLATION = 2 + E_DISK_FULL = 3 + E_ILLEGAL_TFTP_OP = 4 + E_UKNOWN_TID = 5 + E_FILE_ALREADY_EXISTS = 6 + E_NO_SUCH_USER = 7 + E_NO_ERROR = 10 + + def __init__(self, host, port, base, forced_file, reader=None): + # pylint: disable=too-many-arguments + self.host = host + self.port = port + self.base = base + self.forced_file = forced_file + if reader is None: + self.data_reader = self._file_reader + else: + self.data_reader = reader + self.filein = None + self.resends_limit = 5 + # These are here to shut pylint up + self.block = 0 + self.block_size = 512 + self.timeout = 0 + self.resends = 0 + self.finished = False + self.filename = None + self._reinit() + + def _reinit(self): + '''Reinitialise all the class variables used by the protocol.''' + if self.filein is not None: + self.filein.close() + self.filein = None + self.block = 0 + self.block_size = 512 + self.timeout = 0 + self.resends = 0 + self.finished = False + self.filename = None + + def _file_reader(self, command, **kwargs): + '''The default file reader if the user does not provide one. + + The call returns a two element tuple where the first element + is an error code, and the second element is data if the error + code is 0 else it is an error message. + ''' + # pylint: disable=too-many-return-statements + if command == 'open': + if 'filename' not in kwargs: + raise error.general('tftp-reader: invalid open: no filename') + filename = kwargs['filename'] + try: + self.filein = open(filename, 'rb') + filesize = os.stat(filename).st_size + except FileNotFoundError: + return self.E_FILE_NOT_FOUND, 'file not found (%s)' % ( + filename) + except PermissionError: + return self.E_ACCESS_VIOLATION, 'access violation' + except IOError as ioe: + return self.E_NOT_DEFINED, str(ioe) + return self.E_NO_ERROR, str(filesize) + if command == 'read': + if self.filein is None: + raise error.general('tftp-reader: read when not open') + if 'blksize' not in kwargs: + raise error.general('tftp-reader: invalid read: no blksize') + # pylint: disable=bare-except + try: + return self.E_NO_ERROR, self.filein.read(kwargs['blksize']) + except IOError as ioe: + return self.E_NOT_DEFINED, str(ioe) + except: + return self.E_NOT_DEFINED, 'unknown error' + if command == 'close': + if self.filein is not None: + self.filein.close() + self.filein = None + return self.E_NO_ERROR, "closed" + return self.E_NOT_DEFINED, 'invalid reader state' + + @staticmethod + def _pack_bytes(data=None): + bdata = bytearray() + if data is not None: + if not isinstance(data, list): + data = [data] + for item in data: + if isinstance(item, int): + bdata.append(item >> 8) + bdata.append(item & 0xff) + elif isinstance(item, str): + bdata.extend(item.encode()) + bdata.append(0) + else: + bdata.extend(item) + return bdata + + def _response(self, opcode, data): + code = self.opcodes.index(opcode) + if code == 0 or code >= len(self.opcodes): + raise error.general('invalid opcode: ' + opcode) + bdata = self._pack_bytes([code, data]) + #print(''.join(format(x, '02x') for x in bdata)) + return bytes(bdata) + + def _error_response(self, code, message): + if log.tracing: + log.trace('tftp: error: %s:%d: %d: %s' % + (self.host, self.port, code, message)) + self.finished = True + return self._response('ERROR', self._pack_bytes([code, message, 0])) + + def _data_response(self, block, data): + if len(data) < self.block_size: + self.finished = True + return self._response('DATA', self._pack_bytes([block, data])) + + def _oack_response(self, data): + self.resends += 1 + if self.resends >= self.resends_limit: + return self._error_response(self.E_NOT_DEFINED, + 'resend limit reached') + return self._response('OACK', self._pack_bytes(data)) + + def _next_block(self, block): + # has the current block been acknowledged? + if block == self.block: + self.resends = 0 + self.block += 1 + err, data = self.data_reader('read', blksize=self.block_size) + if err != self.E_NO_ERROR: + return self._error_response(err, data) + # close if the length of data is less than the block size + if len(data) < self.block_size: + self.data_reader('close') + else: + self.resends += 1 + if self.resends >= self.resends_limit: + return self._error_response(self.E_NOT_DEFINED, + 'resend limit reached') + return self._data_response(self.block, data) + + def _read_req(self, data): + # if the last block is not 0 something has gone wrong and + # TID match. Restart the session. It could be the client + # is a simple implementation that does not move the send + # port on each retry. + if self.block != 0: + self.data_reader('close') + self._reinit() + # Get the filename, mode and options + self.filename = self.get_option('filename', data) + if self.filename is None: + return self._error_response(self.E_NOT_DEFINED, + 'filename not found in request') + if self.forced_file is not None: + self.filename = self.forced_file + # open the reader + err, message = self.data_reader('open', filename=self.filename) + if err != self.E_NO_ERROR: + return self._error_response(err, message) + # the no error on open message is the file size + try: + tsize = int(message) + except ValueError: + tsize = 0 + mode = self.get_option('mode', data) + if mode is None: + return self._error_response(self.E_NOT_DEFINED, + 'mode not found in request') + oack_data = self._pack_bytes() + value = self.get_option('timeout', data) + if value is not None: + oack_data += self._pack_bytes(['timeout', value]) + self.timeout = int(value) + value = self.get_option('blksize', data) + if value is not None: + oack_data += self._pack_bytes(['blksize', value]) + self.block_size = int(value) + else: + self.block_size = 512 + value = self.get_option('tsize', data) + if value is not None and tsize > 0: + oack_data += self._pack_bytes(['tsize', str(tsize)]) + # Send the options ack + return self._oack_response(oack_data) + + def _write_req(self): + # WRQ is not supported + return self._error_response(self.E_ILLEGAL_TFTP_OP, + "writes not supported") + + def _op_ack(self, data): + # send the next block of data + block = (data[2] << 8) | data[3] + return self._next_block(block) + + def process(self, host, port, data): + '''Process the incoming client data sending a response. If the session + has finished return None. + ''' + if host != self.host and port != self.port: + return self._error_response(self.E_UKNOWN_TID, + 'unkown transfer ID') + if self.finished: + return None + opcode = (data[0] << 8) | data[1] + if opcode == self.OP_RRQ: + return self._read_req(data) + if opcode in [self.OP_WRQ, self.OP_DATA]: + return self._write_req() + if opcode == self.OP_ACK: + return self._op_ack(data) + return self._error_response(self.E_ILLEGAL_TFTP_OP, + "unknown or unsupported opcode") + + def decode(self, host, port, data): + '''Decode the packet for diagnostic purposes. + ''' + # pylint: disable=too-many-branches + out = '' + dlen = len(data) + if dlen > 2: + opcode = (data[0] << 8) | data[1] + if 0 < opcode < len(self.opcodes): + if opcode in [self.OP_RRQ, self.OP_WRQ]: + out += ' ' + self.opcodes[opcode] + ', ' + i = 2 + while data[i] != 0: + out += chr(data[i]) + i += 1 + while i < dlen - 1: + out += ', ' + i += 1 + while data[i] != 0: + out += chr(data[i]) + i += 1 + elif opcode == self.OP_DATA: + block = (data[2] << 8) | data[3] + out += ' ' + self.opcodes[opcode] + ', ' + out += '#' + str(block) + ', ' + if dlen > 4: + out += '%02x%02x..%02x%02x' % (data[4], data[5], + data[-2], data[-1]) + else: + out += '%02x%02x%02x%02x' % (data[4], data[5], data[6], + data[6]) + out += ',' + str(dlen - 4) + elif opcode == self.OP_ACK: + block = (data[2] << 8) | data[3] + out += ' ' + self.opcodes[opcode] + ' ' + str(block) + elif opcode == self.OP_ERROR: + out += 'E ' + self.opcodes[opcode] + ', ' + out += str((data[2] << 8) | (data[3])) + out += ': ' + str(data[4:].decode()) + i = 2 + while data[i] != 0: + out += chr(data[i]) + i += 1 + elif opcode == self.OP_OACK: + out += ' ' + self.opcodes[opcode] + i = 1 + while i < dlen - 1: + out += ', ' + i += 1 + while data[i] != 0: + out += chr(data[i]) + i += 1 + else: + out += 'E INV(%d)' % (opcode) + else: + out += 'E INVALID LENGTH' + return out[:2] + '[%s:%d] (%d) ' % (host, port, len(data)) + out[2:] + + @staticmethod + def get_option(option, data): + '''Get the option from the TFTP packet.''' + dlen = len(data) - 1 + opcode = (data[0] << 8) | data[1] + next_option = False + if opcode in [1, 2]: + count = 0 + i = 2 + while i < dlen: + value = '' + while data[i] != 0: + value += chr(data[i]) + i += 1 + i += 1 + if option == 'filename' and count == 0: + return value + if option == 'mode' and count == 1: + return value + if value == option and (count % 1) == 0: + next_option = True + elif next_option: + return value + count += 1 + return None + + def get_timeout(self, default_timeout, timeout_guard): + '''Get the timeout. The timeout can be an option.''' + if self.timeout == 0: + return self.timeout + timeout_guard + return default_timeout + + def get_block_size(self): + '''Get the block size. The block size can be an option.''' + return self.block_size + + +class udp_handler(socketserver.BaseRequestHandler): + '''TFTP UDP handler for a TFTP session.''' + def _notice(self, text): + if self.server.tftp.notices: + log.notice(text) + else: + log.trace(text) + + def handle_session(self, index): + '''Handle the TFTP session data.''' + # pylint: disable=too-many-locals + # pylint: disable=broad-except + # pylint: disable=too-many-branches + client_ip = self.client_address[0] + client_port = self.client_address[1] + client = '%s:%i' % (client_ip, client_port) + self._notice('] tftp: %d: start: %s' % (index, client)) + try: + session = tftp_session(client_ip, client_port, + self.server.tftp.base, + self.server.tftp.forced_file, + self.server.tftp.reader) + response = session.process(client_ip, client_port, self.request[0]) + if response is not None: + if log.tracing and self.server.tftp.packet_trace: + log.trace(' > ' + session.decode(client_ip, client_port, + self.request[0])) + timeout = session.get_timeout(self.server.tftp.timeout, 1) + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.bind(('', 0)) + sock.settimeout(timeout) + while response is not None: + if log.tracing and self.server.tftp.packet_trace: + log.trace( + ' < ' + + session.decode(client_ip, client_port, response)) + sock.sendto(response, (client_ip, client_port)) + if session.finished: + break + try: + data, address = sock.recvfrom(2 + 2 + + session.get_block_size()) + if log.tracing and self.server.tftp.packet_trace: + log.trace( + ' > ' + + session.decode(address[0], address[1], data)) + except socket.error as serr: + if log.tracing: + log.trace('] tftp: %d: receive: %s: error: %s' \ + % (index, client, serr)) + return + except socket.gaierror as serr: + if log.tracing: + log.trace('] tftp: %d: receive: %s: error: %s' \ + % (index, client, serr)) + return + response = session.process(address[0], address[1], data) + except error.general as gerr: + self._notice('] tftp: %dd: error: %s' % (index, gerr)) + except error.internal as ierr: + self._notice('] tftp: %d: error: %s' % (index, ierr)) + except error.exit: + pass + except KeyboardInterrupt: + pass + except Exception as exp: + self._notice('] tftp: %d: error: %s: %s' % (index, type(exp), exp)) + self._notice('] tftp: %d: end: %s' % (index, client)) + + def handle(self): + '''The UDP server handle method.''' + if self.server.tftp.sessions is None \ + or self.server.tftp.session < self.server.tftp.sessions: + self.handle_session(self.server.tftp.next_session()) + + +class udp_server(socketserver.ThreadingMixIn, socketserver.UDPServer): + '''UDP server. Default behaviour.''' + + +class tftp_server(object): + '''TFTP server runs a UDP server to handle TFTP sessions.''' + + # pylint: disable=useless-object-inheritance + # pylint: disable=too-many-instance-attributes + + def __init__(self, + host, + port, + timeout=10, + base=None, + forced_file=None, + sessions=None, + reader=None): + # pylint: disable=too-many-arguments + self.lock = threading.Lock() + self.notices = False + self.packet_trace = False + self.timeout = timeout + self.host = host + self.port = port + self.server = None + self.server_thread = None + if base is None: + base = os.getcwd() + self.base = base + self.forced_file = forced_file + if sessions is not None and not isinstance(sessions, int): + raise error.general('tftp session count is not a number') + self.sessions = sessions + self.session = 0 + self.reader = reader + + def __del__(self): + self.stop() + + def _lock(self): + self.lock.acquire() + + def _unlock(self): + self.lock.release() + + def start(self): + '''Start the TFTP server. Returns once started.''' + # pylint: disable=attribute-defined-outside-init + if log.tracing: + log.trace('] tftp: server: %s:%i' % (self.host, self.port)) + if self.host == 'all': + host = '' + else: + host = self.host + try: + self.server = udp_server((host, self.port), udp_handler) + except Exception as exp: + raise error.general('tftp server create: %s' % (exp)) + # We cannot set tftp in __init__ because the object is created + # in a separate package. + self.server.tftp = self + self.server_thread = threading.Thread(target=self.server.serve_forever) + self.server_thread.daemon = True + self.server_thread.start() + + def stop(self): + '''Stop the TFTP server and close the server port.''' + self._lock() + try: + if self.server is not None: + self.server.shutdown() + self.server.server_close() + self.server = None + finally: + self._unlock() + + def run(self): + '''Run the TFTP server for the specified number of sessions.''' + running = True + while running: + period = 1 + self._lock() + if self.server is None: + running = False + period = 0 + elif self.sessions is not None: + if self.sessions == 0: + running = False + period = 0 + else: + period = 0.25 + self._unlock() + if period > 0: + time.sleep(period) + self.stop() + + def get_session(self): + '''Return the session count.''' + count = 0 + self._lock() + try: + count = self.session + finally: + self._unlock() + return count + + def next_session(self): + '''Return the next session number.''' + count = 0 + self._lock() + try: + self.session += 1 + count = self.session + finally: + self._unlock() + return count + + def enable_notices(self): + '''Call to enable notices. The server is quiet without this call.''' + self._lock() + self.notices = True + self._unlock() + + def trace_packets(self): + '''Call to enable packet tracing as a diagnostic.''' + self._lock() + self.packet_trace = True + self._unlock() + + +def load_log(logfile): + '''Set the log file.''' + if logfile is None: + log.default = log.log(streams=['stdout']) + else: + log.default = log.log(streams=[logfile]) + + +def run(args=sys.argv, command_path=None): + '''Run a TFTP server session.''' + # pylint: disable=dangerous-default-value + # pylint: disable=unused-argument + # pylint: disable=too-many-statements + ecode = 0 + notice = None + server = None + # pylint: disable=bare-except + try: + description = 'A TFTP Server that supports a read only TFTP session.' + + nice_cwd = os.path.relpath(os.getcwd()) + if len(nice_cwd) > len(os.path.abspath(nice_cwd)): + nice_cwd = os.path.abspath(nice_cwd) + + argsp = argparse.ArgumentParser(prog='rtems-tftp-server', + description=description) + argsp.add_argument('-l', + '--log', + help='log file.', + type=str, + default=None) + argsp.add_argument('-v', + '--trace', + help='enable trace logging for debugging.', + action='store_true', + default=False) + argsp.add_argument('--trace-packets', + help='enable trace logging of packets.', + action='store_true', + default=False) + argsp.add_argument( + '-B', + '--bind', + help='address to bind the server too (default: %(default)s).', + type=str, + default='all') + argsp.add_argument( + '-P', + '--port', + help='port to bind the server too (default: %(default)s).', + type=int, + default='69') + argsp.add_argument('-t', '--timeout', + help = 'timeout in seconds, client can override ' \ + '(default: %(default)s).', + type = int, default = '10') + argsp.add_argument( + '-b', + '--base', + help='base path, not checked (default: %(default)s).', + type=str, + default=nice_cwd) + argsp.add_argument( + '-F', + '--force-file', + help='force the file to be downloaded overriding the client.', + type=str, + default=None) + argsp.add_argument('-s', '--sessions', + help = 'number of TFTP sessions to run before exiting ' \ + '(default: forever.', + type = int, default = None) + + argopts = argsp.parse_args(args[1:]) + + load_log(argopts.log) + log.notice('RTEMS Tools - TFTP Server, %s' % (version.string())) + log.output(log.info(args)) + log.tracing = argopts.trace + + server = tftp_server(argopts.bind, argopts.port, argopts.timeout, + argopts.base, argopts.force_file, + argopts.sessions) + server.enable_notices() + + try: + server.start() + server.run() + finally: + server.stop() + + except error.general as gerr: + notice = str(gerr) + ecode = 1 + except error.internal as ierr: + notice = str(ierr) + ecode = 1 + except error.exit: + pass + except KeyboardInterrupt: + notice = 'abort: user terminated' + ecode = 1 + except SystemExit: + pass + except: + notice = 'abort: unknown error' + ecode = 1 + if server is not None: + del server + if notice is not None: + log.stderr(notice) + sys.exit(ecode) + + +if __name__ == "__main__": + run() |