summaryrefslogtreecommitdiff
path: root/tester/rt/tftpy/TftpContexts.py
diff options
context:
space:
mode:
Diffstat (limited to 'tester/rt/tftpy/TftpContexts.py')
-rw-r--r--tester/rt/tftpy/TftpContexts.py47
1 files changed, 35 insertions, 12 deletions
diff --git a/tester/rt/tftpy/TftpContexts.py b/tester/rt/tftpy/TftpContexts.py
index 271441b..da85886 100644
--- a/tester/rt/tftpy/TftpContexts.py
+++ b/tester/rt/tftpy/TftpContexts.py
@@ -1,3 +1,5 @@
+# vim: ts=4 sw=4 et ai:
+# -*- coding: utf8 -*-
"""This module implements all contexts for state handling during uploads and
downloads, the main interface to which being the TftpContext base class.
@@ -8,12 +10,18 @@ the next packet in the transfer, and returns a state object until the transfer
is complete, at which point it returns None. That is, unless there is a fatal
error, in which case a TftpException is returned instead."""
-from __future__ import absolute_import, division, print_function, unicode_literals
+
from .TftpShared import *
from .TftpPacketTypes import *
from .TftpPacketFactory import TftpPacketFactory
from .TftpStates import *
-import socket, time, sys
+import socket
+import time
+import sys
+import os
+import logging
+
+log = logging.getLogger('tftpy.TftpContext')
###############################################################################
# Utility classes
@@ -120,13 +128,14 @@ class TftpContext(object):
def start(self):
raise NotImplementedError("Abstract method")
- def end(self):
+ def end(self, close_fileobj=True):
"""Perform session cleanup, since the end method should always be
called explicitely by the calling code, this works better than the
- destructor."""
- log.debug("in TftpContext.end")
+ destructor.
+ Set close_fileobj to False so fileobj can be returned open."""
+ log.debug("in TftpContext.end - closing socket")
self.sock.close()
- if self.fileobj is not None and not self.fileobj.closed:
+ if close_fileobj and self.fileobj is not None and not self.fileobj.closed:
log.debug("self.fileobj is open - closing")
self.fileobj.close()
@@ -159,7 +168,7 @@ class TftpContext(object):
try:
(buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
except socket.timeout:
- log.warn("Timeout waiting for traffic, retrying...")
+ log.warning("Timeout waiting for traffic, retrying...")
raise TftpTimeout("Timed-out waiting for traffic")
# Ok, we've received a packet. Log it.
@@ -173,11 +182,11 @@ class TftpContext(object):
# Check for known "connection".
if raddress != self.address:
- log.warn("Received traffic from %s, expected host %s. Discarding"
+ log.warning("Received traffic from %s, expected host %s. Discarding"
% (raddress, self.host))
if self.tidport and self.tidport != rport:
- log.warn("Received traffic from %s:%s but we're "
+ log.warning("Received traffic from %s:%s but we're "
"connected to %s:%s. Discarding."
% (raddress, rport,
self.host, self.tidport))
@@ -315,7 +324,7 @@ class TftpContextClientUpload(TftpContext):
log.debug("hit max retries, giving up")
raise
else:
- log.warn("resending last packet")
+ log.warning("resending last packet")
self.state.resendLast()
def end(self):
@@ -347,13 +356,16 @@ class TftpContextClientDownload(TftpContext):
self.file_to_transfer = filename
self.options = options
self.packethook = packethook
+ self.filelike_fileobj = False
# If the output object has a write() function,
# assume it is file-like.
if hasattr(output, 'write'):
self.fileobj = output
+ self.filelike_fileobj = True
# If the output filename is -, then use stdout
elif output == '-':
self.fileobj = sys.stdout
+ self.filelike_fileobj = True
else:
self.fileobj = open(output, "wb")
@@ -395,12 +407,23 @@ class TftpContextClientDownload(TftpContext):
log.debug("hit max retries, giving up")
raise
else:
- log.warn("resending last packet")
+ log.warning("resending last packet")
self.state.resendLast()
+ except TftpFileNotFoundError as err:
+ # If we received file not found, then we should not save the open
+ # output file or we'll be left with a size zero file. Delete it,
+ # if it exists.
+ log.error("Received File not found error")
+ if self.fileobj is not None and not self.filelike_fileobj:
+ if os.path.exists(self.fileobj.name):
+ log.debug("unlinking output file of %s", self.fileobj.name)
+ os.unlink(self.fileobj.name)
+
+ raise
def end(self):
"""Finish up the context."""
- TftpContext.end(self)
+ TftpContext.end(self, not self.filelike_fileobj)
self.metrics.end_time = time.time()
log.debug("Set metrics.end_time to %s" % self.metrics.end_time)
self.metrics.compute()