summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--tester/rt/tftpserver.py32
1 files changed, 28 insertions, 4 deletions
diff --git a/tester/rt/tftpserver.py b/tester/rt/tftpserver.py
index 012aae4..d733301 100644
--- a/tester/rt/tftpserver.py
+++ b/tester/rt/tftpserver.py
@@ -82,6 +82,7 @@ class tftp_session(object):
self.resends_limit = 5
# These are here to shut pylint up
self.block = 0
+ self.last_data = None
self.block_size = 512
self.timeout = 0
self.resends = 0
@@ -95,6 +96,7 @@ class tftp_session(object):
self.filein.close()
self.filein = None
self.block = 0
+ self.last_data = None
self.block_size = 512
self.timeout = 0
self.resends = 0
@@ -166,7 +168,7 @@ class tftp_session(object):
raise error.general('invalid opcode: ' + opcode)
bdata = self._pack_bytes([code, data])
#print(''.join(format(x, '02x') for x in bdata))
- return bytes(bdata)
+ return bdata
def _error_response(self, code, message):
if log.tracing:
@@ -193,16 +195,19 @@ class tftp_session(object):
self.resends = 0
self.block += 1
err, data = self.data_reader('read', blksize=self.block_size)
+ data = bytearray(data)
if err != self.E_NO_ERROR:
return self._error_response(err, data)
# close if the length of data is less than the block size
if len(data) < self.block_size:
self.data_reader('close')
+ self.last_data = data
else:
self.resends += 1
if self.resends >= self.resends_limit:
return self._error_response(self.E_NOT_DEFINED,
'resend limit reached')
+ data = self.last_data
return self._data_response(self.block, data)
def _read_req(self, data):
@@ -387,6 +392,7 @@ class udp_handler(socketserver.BaseRequestHandler):
# pylint: disable=too-many-locals
# pylint: disable=broad-except
# pylint: disable=too-many-branches
+ # pylint: disable=too-many-statements
client_ip = self.client_address[0]
client_port = self.client_address[1]
client = '%s:%i' % (client_ip, client_port)
@@ -396,11 +402,12 @@ class udp_handler(socketserver.BaseRequestHandler):
self.server.tftp.base,
self.server.tftp.forced_file,
self.server.tftp.reader)
- response = session.process(client_ip, client_port, self.request[0])
+ data = bytearray(self.request[0])
+ response = session.process(client_ip, client_port, data)
if response is not None:
if log.tracing and self.server.tftp.packet_trace:
- log.trace(' > ' + session.decode(client_ip, client_port,
- self.request[0]))
+ log.trace(' > ' +
+ session.decode(client_ip, client_port, data))
timeout = session.get_timeout(self.server.tftp.timeout, 1)
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.bind(('', 0))
@@ -416,6 +423,7 @@ class udp_handler(socketserver.BaseRequestHandler):
try:
data, address = sock.recvfrom(2 + 2 +
session.get_block_size())
+ data = bytearray(data)
if log.tracing and self.server.tftp.packet_trace:
log.trace(
' > ' +
@@ -440,6 +448,8 @@ class udp_handler(socketserver.BaseRequestHandler):
except KeyboardInterrupt:
pass
except Exception as exp:
+ if self.server.tftp.exception_is_raise:
+ raise
self._notice('] tftp: %d: error: %s: %s' % (index, type(exp), exp))
self._notice('] tftp: %d: end: %s' % (index, client))
@@ -472,6 +482,7 @@ class tftp_server(object):
self.lock = threading.Lock()
self.notices = False
self.packet_trace = False
+ self.exception_is_raise = False
self.timeout = timeout
self.host = host
self.port = port
@@ -580,6 +591,10 @@ class tftp_server(object):
self.packet_trace = True
self._unlock()
+ def except_is_raise(self):
+ '''If True a standard exception will generate a backtrace.'''
+ self.exception_is_raise = True
+
def load_log(logfile):
'''Set the log file.'''
@@ -593,6 +608,7 @@ def run(args=sys.argv, command_path=None):
'''Run a TFTP server session.'''
# pylint: disable=dangerous-default-value
# pylint: disable=unused-argument
+ # pylint: disable=too-many-branches
# pylint: disable=too-many-statements
ecode = 0
notice = None
@@ -621,6 +637,10 @@ def run(args=sys.argv, command_path=None):
help='enable trace logging of packets.',
action='store_true',
default=False)
+ argsp.add_argument('--show-backtrace',
+ help='show the exception backtrace.',
+ action='store_true',
+ default=False)
argsp.add_argument(
'-B',
'--bind',
@@ -665,6 +685,10 @@ def run(args=sys.argv, command_path=None):
argopts.base, argopts.force_file,
argopts.sessions)
server.enable_notices()
+ if argopts.trace_packets:
+ server.trace_packets()
+ if argopts.show_backtrace:
+ server.except_is_raise()
try:
server.start()