summaryrefslogtreecommitdiffstats
path: root/tester/rt/tftpserver.py
diff options
context:
space:
mode:
Diffstat (limited to 'tester/rt/tftpserver.py')
-rw-r--r--tester/rt/tftpserver.py57
1 files changed, 50 insertions, 7 deletions
diff --git a/tester/rt/tftpserver.py b/tester/rt/tftpserver.py
index 92cd1fd..c200dad 100644
--- a/tester/rt/tftpserver.py
+++ b/tester/rt/tftpserver.py
@@ -453,14 +453,13 @@ class udp_handler(socketserver.BaseRequestHandler):
raise
self._notice('] tftp: %d: error: %s: %s' % (index, type(exp), exp))
self._notice('] tftp: %d: end: %s' % (index, client))
+ self.server.tftp.session_done()
def handle(self):
'''The UDP server handle method.'''
- if self.server.tftp.sessions is None \
- or self.server.tftp.session < self.server.tftp.sessions:
+ if self.server.tftp.sessions_available():
self.handle_session(self.server.tftp.next_session())
-
class udp_server(socketserver.ThreadingMixIn, socketserver.UDPServer):
'''UDP server. Default behaviour.'''
@@ -474,6 +473,7 @@ class tftp_server(object):
def __init__(self,
host,
port,
+ session_timeout=None,
timeout=10,
base=None,
forced_file=None,
@@ -484,6 +484,7 @@ class tftp_server(object):
self.notices = False
self.packet_trace = False
self.exception_is_raise = False
+ self.session_timeout = session_timeout
self.timeout = timeout
self.host = host
self.port = port
@@ -497,6 +498,7 @@ class tftp_server(object):
raise error.general('tftp session count is not a number')
self.sessions = sessions
self.session = 0
+ self.sessions_done = 0
self.reader = reader
def __del__(self):
@@ -542,6 +544,8 @@ class tftp_server(object):
def run(self):
'''Run the TFTP server for the specified number of sessions.'''
running = True
+ session_timeout = self.session_timeout
+ last_session = 0
while running:
period = 1
self._lock()
@@ -549,7 +553,7 @@ class tftp_server(object):
running = False
period = 0
elif self.sessions is not None:
- if self.sessions == 0:
+ if self.sessions_done >= self.sessions:
running = False
period = 0
else:
@@ -557,7 +561,24 @@ class tftp_server(object):
self._unlock()
if period > 0:
time.sleep(period)
+ if session_timeout is not None:
+ session = self.get_session()
+ if last_session != session:
+ last_session = session
+ session_timeout = self.session_timeout
+ else:
+ if session_timeout < period:
+ session_timeout = 0
+ else:
+ session_timeout -= period
+ if session_timeout == 0:
+ log.trace('] tftp: server: session timeout')
+ running = False
self.stop()
+ self._lock()
+ sessions_done = self.sessions_done
+ self._unlock()
+ return sessions_done
def get_session(self):
'''Return the session count.'''
@@ -580,6 +601,24 @@ class tftp_server(object):
self._unlock()
return count
+ def sessions_available(self):
+ '''Return True is there are available sessions.'''
+ available = False
+ self._lock()
+ try:
+ available = self.sessions is None or self.session < self.sessions
+ finally:
+ self._unlock()
+ return available
+
+ def session_done(self):
+ '''Call when a session is done.'''
+ self._lock()
+ try:
+ self.sessions_done += 1
+ finally:
+ self._unlock()
+
def enable_notices(self):
'''Call to enable notices. The server is quiet without this call.'''
self._lock()
@@ -654,10 +693,14 @@ def run(args=sys.argv, command_path=None):
help='port to bind the server too (default: %(default)s).',
type=int,
default='69')
+ argsp.add_argument('-S', '--session-timeout',
+ help='timeout in seconds, client can override ' \
+ '(default: %(default)s).',
+ type = int, default=None)
argsp.add_argument('-t', '--timeout',
- help = 'timeout in seconds, client can override ' \
+ help='timeout in seconds, client can override ' \
'(default: %(default)s).',
- type = int, default = '10')
+ type=int, default='10')
argsp.add_argument(
'-b',
'--base',
@@ -682,7 +725,7 @@ def run(args=sys.argv, command_path=None):
log.output(log.info(args))
log.tracing = argopts.trace
- server = tftp_server(argopts.bind, argopts.port, argopts.timeout,
+ server = tftp_server(argopts.bind, argopts.port, argopts.session_timeout, argopts.timeout,
argopts.base, argopts.force_file,
argopts.sessions)
server.enable_notices()