summaryrefslogtreecommitdiffstats
path: root/tester/rt/tftpy/TftpContexts.py
blob: 271441b13577e2c9b50ce152ff928c21894f2406 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
"""This module implements all contexts for state handling during uploads and
downloads, the main interface to which is the TftpContext base class.

The concept is simple. Each context object represents a single upload or
download, and the state object in the context object represents the current
state of that transfer. The state object has a handle() method that expects
the next packet in the transfer, and returns a state object until the transfer
is complete, at which point it returns None. If a fatal error occurs, a
TftpException is raised instead."""

from __future__ import absolute_import, division, print_function, unicode_literals
from .TftpShared import *
from .TftpPacketTypes import *
from .TftpPacketFactory import TftpPacketFactory
from .TftpStates import *
import socket, time, sys

###############################################################################
# Utility classes
###############################################################################

class TftpMetrics(object):
    """A class representing metrics of the transfer.

    The counters are updated while the transfer runs; compute() derives the
    duration, the bit rates and the total duplicate count from them once
    start_time and end_time have been set.
    """
    def __init__(self):
        # Bytes transferred
        self.bytes = 0
        # Bytes re-sent
        self.resent_bytes = 0
        # Duplicate packets received, keyed by str(pkt), and their total
        # (the total is filled in by compute()).
        self.dups = {}
        self.dupcount = 0
        # Times (the contexts fill these from time.time())
        self.start_time = 0
        self.end_time = 0
        self.duration = 0
        # Rates
        self.bps = 0
        self.kbps = 0
        # Generic errors
        self.errors = 0

    def compute(self):
        """Derive duration, bps/kbps and dupcount from the raw counters.

        Safe to call more than once: dupcount is recomputed from self.dups
        instead of being accumulated, so an explicit end() followed by the
        one triggered from __del__ does not inflate it.
        """
        # Compute transfer time; a zero duration is clamped to 1 so the
        # rate computation below never divides by zero.
        self.duration = self.end_time - self.start_time
        if self.duration == 0:
            self.duration = 1
        log.debug("TftpMetrics.compute: duration is %s", self.duration)
        self.bps = (self.bytes * 8.0) / self.duration
        self.kbps = self.bps / 1024.0
        log.debug("TftpMetrics.compute: kbps is %s", self.kbps)
        self.dupcount = sum(self.dups.values())

    def add_dup(self, pkt):
        """This method adds a dup for a packet to the metrics.

        Duplicates are keyed by str(pkt). Once a single packet has been seen
        MAX_DUPS times the tftpassert check fails (presumably raising a
        TftpException - see TftpShared).
        """
        log.debug("Recording a dup of %s", pkt)
        s = str(pkt)
        self.dups[s] = self.dups.get(s, 0) + 1
        tftpassert(self.dups[s] < MAX_DUPS, "Max duplicates reached")

###############################################################################
# Context classes
###############################################################################

class TftpContext(object):
    """The base class of the contexts.

    Holds the UDP socket, the current state object, block/option bookkeeping
    and the transfer metrics shared by the server and client contexts.
    Subclasses must implement start().
    """

    def __init__(self, host, port, timeout, localip = ""):
        """Constructor for the base context, setting shared instance
        variables.

        host/port identify the peer. timeout is applied to the socket and is
        also used by checkTimeout(). A non-empty localip binds the socket to
        that local address on an ephemeral port.
        """
        self.file_to_transfer = None
        self.fileobj = None
        self.options = None
        self.packethook = None
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        if localip != "":
            self.sock.bind((localip, 0))
        self.sock.settimeout(timeout)
        self.timeout = timeout
        self.state = None
        self.next_block = 0
        self.factory = TftpPacketFactory()
        # Note, setting the host will also set self.address, as it's a property.
        self.host = host
        self.port = port
        # The port associated with the TID
        self.tidport = None
        # Metrics
        self.metrics = TftpMetrics()
        # Flag when the transfer is pending completion.
        self.pending_complete = False
        # Time when this context last received any traffic.
        # FIXME: does this belong in metrics?
        self.last_update = 0
        # The last packet we sent, if applicable, to make resending easy.
        self.last_pkt = None
        # Count the number of retry attempts.
        self.retry_count = 0

    def getBlocksize(self):
        """Fetch the current blocksize for this session.

        Falls back to the default of 512 when no 'blksize' option is present,
        including when no options have been set at all (options is None).
        """
        return int((self.options or {}).get('blksize', 512))

    def __del__(self):
        """Simple destructor to try to call housekeeping in the end method if
        not called explicitly. Leaking file descriptors is not a good
        thing."""
        self.end()

    def checkTimeout(self, now):
        """Compare current time with last_update time, and raise an exception
        if we're over the timeout time.

        now is a time.time() timestamp, the same clock last_update uses.
        Raises TftpTimeout when more than self.timeout seconds have passed.
        """
        log.debug("checking for timeout on session %s", self)
        if now - self.last_update > self.timeout:
            raise TftpTimeout("Timeout waiting for traffic")

    def start(self):
        raise NotImplementedError("Abstract method")

    @staticmethod
    def _stdio_streams():
        """Return the process's standard streams plus their binary buffers."""
        streams = []
        for stream in (sys.stdin, sys.stdout, sys.stderr):
            streams.append(stream)
            buf = getattr(stream, 'buffer', None)
            if buf is not None:
                streams.append(buf)
        return streams

    def end(self):
        """Perform session cleanup, since the end method should always be
        called explicitly by the calling code, this works better than the
        destructor.

        Closes the socket and the transfer file object. The process's
        standard streams (used for '-' transfers) are only flushed, never
        closed, so the program can keep using them afterwards. Safe to call
        repeatedly, and on a partially constructed instance (which __del__
        may see if __init__ raised).
        """
        log.debug("in TftpContext.end")
        sock = getattr(self, 'sock', None)
        if sock is not None:
            sock.close()
        fileobj = getattr(self, 'fileobj', None)
        # File-like objects supplied by callers may lack a 'closed' attribute.
        if fileobj is None or getattr(fileobj, 'closed', False):
            return
        if any(fileobj is stream for stream in self._stdio_streams()):
            log.debug("self.fileobj is a standard stream - flushing only")
            fileobj.flush()
            return
        log.debug("self.fileobj is open - closing")
        fileobj.close()

    def gethost(self):
        "Simple getter method for use in a property."
        return self.__host

    def sethost(self, host):
        """Setter method that also sets the address property as a result
        of the host that is set."""
        self.__host = host
        self.address = socket.gethostbyname(host)

    host = property(gethost, sethost)

    def setNextBlock(self, block):
        """Store the next expected block number, wrapping values >= 2**16
        back to 0 so it always fits the 16-bit block field."""
        if block >= 2 ** 16:
            log.debug("Block number rollover to 0 again")
            block = 0
        self.__eblock = block

    def getNextBlock(self):
        return self.__eblock

    next_block = property(getNextBlock, setNextBlock)

    def cycle(self):
        """Here we wait for a response from the server after sending it
        something, and dispatch appropriate action to that response.

        Raises TftpTimeout if nothing arrives before the socket timeout.
        A packet whose source port differs from the established TID port is
        logged and dropped without being parsed or handed to the state
        machine, as RFC 1350 requires for a mismatched TID.
        """
        try:
            (buffer, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
        except socket.timeout:
            log.warning("Timeout waiting for traffic, retrying...")
            raise TftpTimeout("Timed-out waiting for traffic")

        # Ok, we've received a packet. Log it.
        log.debug("Received %d bytes from %s:%s",
                  len(buffer), raddress, rport)

        # Check for known "connection".
        if raddress != self.address:
            # Deliberately not dropped: a multi-homed peer may answer from a
            # different address than the one we resolved.
            log.warning("Received traffic from %s, expected host %s. "
                        "Accepting it anyway.", raddress, self.host)

        if self.tidport and self.tidport != rport:
            # Wrong transfer ID: ignore the packet without disturbing the
            # transfer, and without letting it reset our timeout clock.
            log.warning("Received traffic from %s:%s but we're "
                        "connected to %s:%s. Discarding.",
                        raddress, rport, self.host, self.tidport)
            return

        # And update our last updated time.
        self.last_update = time.time()

        # Decode it.
        recvpkt = self.factory.parse(buffer)

        # If there is a packethook defined, call it. We unconditionally
        # pass all packets, it's up to the client to screen out different
        # kinds of packets. This way, the client is privy to things like
        # negotiated options.
        if self.packethook:
            self.packethook(recvpkt)

        # And handle it, possibly changing state.
        self.state = self.state.handle(recvpkt, raddress, rport)
        # If we didn't throw any exceptions here, reset the retry_count to
        # zero.
        self.retry_count = 0

class TftpContextServer(TftpContext):
    """Server-side context for one transfer.

    Whether the transfer is an upload or a download is decided by the first
    packet, which start() hands to the initial server state. The TftpServer
    object, not this class, is expected to drive subsequent cycle() calls.
    """
    def __init__(self,
                 host,
                 port,
                 timeout,
                 root,
                 dyn_file_func=None,
                 upload_open=None):
        TftpContext.__init__(self, host, port, timeout)
        # Upload vs. download is unknown until the first packet is seen, so
        # begin in the generic start state and let it decide.
        self.state = TftpStateServerStart(self)

        self.root = root
        self.dyn_file_func = dyn_file_func
        self.upload_open = upload_open

    def __str__(self):
        return "{0}:{1} {2}".format(self.host, self.port, self.state)

    def start(self, buffer):
        """Begin the state cycle with the initial packet in buffer.

        Unlike the client contexts this does not loop on cycle(); it only
        records the start time, parses the first packet and lets the start
        state move the context into the upload or download state.
        """
        log.debug("In TftpContextServer.start")
        metrics = self.metrics
        metrics.start_time = time.time()
        log.debug("Set metrics.start_time to %s", metrics.start_time)
        # Traffic was just received, so refresh the activity timestamp.
        self.last_update = time.time()

        first_pkt = self.factory.parse(buffer)
        log.debug("TftpContextServer.start() - factory returned a %s", first_pkt)

        # A single handle() call with the first packet selects the download
        # or upload state.
        self.state = self.state.handle(first_pkt, self.host, self.port)

    def end(self):
        """Clean up the session, then stamp and compute the metrics."""
        TftpContext.end(self)
        metrics = self.metrics
        metrics.end_time = time.time()
        log.debug("Set metrics.end_time to %s", metrics.end_time)
        metrics.compute()

class TftpContextClientUpload(TftpContext):
    """The upload context for the client during an upload.
    Note: If input is a hyphen, then we will use stdin.

    input may be a file-like object with read(), '-' for standard input, or
    a path that is opened in binary mode.
    """
    def __init__(self,
                 host,
                 port,
                 filename,
                 input,
                 options,
                 packethook,
                 timeout,
                 localip = ""):
        TftpContext.__init__(self,
                             host,
                             port,
                             timeout,
                             localip)
        self.file_to_transfer = filename
        self.options = options
        self.packethook = packethook
        # If the input object has a read() function,
        # assume it is file-like.
        if hasattr(input, 'read'):
            self.fileobj = input
        elif input == '-':
            # The transfer is sent in "octet" mode, so read raw bytes: on
            # Python 3 that is the binary buffer under the text-mode stdin.
            self.fileobj = getattr(sys.stdin, 'buffer', sys.stdin)
        else:
            self.fileobj = open(input, "rb")

        log.debug("TftpContextClientUpload.__init__()")
        log.debug("file_to_transfer = %s, options = %s",
                  self.file_to_transfer, self.options)

    def __str__(self):
        return "%s:%s %s" % (self.host, self.port, self.state)

    def start(self):
        """Send the WRQ and drive the state machine until the upload ends.

        Each TftpTimeout raised by cycle() resends the last packet. After
        TIMEOUT_RETRIES consecutive timeouts (cycle() resets the counter on
        every handled packet) the TftpTimeout propagates to the caller.
        """
        log.info("Sending tftp upload request to %s", self.host)
        log.info("    filename -> %s", self.file_to_transfer)
        log.info("    options -> %s", self.options)

        self.metrics.start_time = time.time()
        log.debug("Set metrics.start_time to %s", self.metrics.start_time)

        # FIXME: put this in a sendWRQ method?
        pkt = TftpPacketWRQ()
        pkt.filename = self.file_to_transfer
        pkt.mode = "octet" # FIXME - shouldn't hardcode this
        pkt.options = self.options
        self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
        self.next_block = 1
        self.last_pkt = pkt
        # FIXME: should we centralize sendto operations so we can refactor all
        # saving of the packet to the last_pkt field?

        self.state = TftpStateSentWRQ(self)

        while self.state:
            try:
                log.debug("State is %s", self.state)
                self.cycle()
            except TftpTimeout as err:
                log.error(str(err))
                self.retry_count += 1
                if self.retry_count >= TIMEOUT_RETRIES:
                    log.debug("hit max retries, giving up")
                    raise
                else:
                    log.warning("resending last packet")
                    self.state.resendLast()

    def end(self):
        """Finish up the context."""
        TftpContext.end(self)
        self.metrics.end_time = time.time()
        log.debug("Set metrics.end_time to %s", self.metrics.end_time)
        self.metrics.compute()


class TftpContextClientDownload(TftpContext):
    """The download context for the client during a download.
    Note: If output is a hyphen, then the output will be sent to stdout.

    output may be a file-like object with write(), '-' for standard output,
    or a path that is opened for binary writing.
    """
    def __init__(self,
                 host,
                 port,
                 filename,
                 output,
                 options,
                 packethook,
                 timeout,
                 localip = ""):
        TftpContext.__init__(self,
                             host,
                             port,
                             timeout,
                             localip)
        # FIXME: should we refactor setting of these params?
        self.file_to_transfer = filename
        self.options = options
        self.packethook = packethook
        # If the output object has a write() function,
        # assume it is file-like.
        if hasattr(output, 'write'):
            self.fileobj = output
        # If the output filename is -, then use stdout
        elif output == '-':
            # The transfer is requested in "octet" mode, so the received data
            # is raw bytes: on Python 3 write it to the binary buffer under
            # the text-mode stdout.
            self.fileobj = getattr(sys.stdout, 'buffer', sys.stdout)
        else:
            self.fileobj = open(output, "wb")

        log.debug("TftpContextClientDownload.__init__()")
        log.debug("file_to_transfer = %s, options = %s",
                  self.file_to_transfer, self.options)

    def __str__(self):
        return "%s:%s %s" % (self.host, self.port, self.state)

    def start(self):
        """Initiate the download.

        Sends the RRQ and drives the state machine until the download ends.
        Each TftpTimeout raised by cycle() resends the last packet. After
        TIMEOUT_RETRIES consecutive timeouts (cycle() resets the counter on
        every handled packet) the TftpTimeout propagates to the caller.
        """
        log.info("Sending tftp download request to %s", self.host)
        log.info("    filename -> %s", self.file_to_transfer)
        log.info("    options -> %s", self.options)

        self.metrics.start_time = time.time()
        log.debug("Set metrics.start_time to %s", self.metrics.start_time)

        # FIXME: put this in a sendRRQ method?
        pkt = TftpPacketRRQ()
        pkt.filename = self.file_to_transfer
        pkt.mode = "octet" # FIXME - shouldn't hardcode this
        pkt.options = self.options
        self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
        self.next_block = 1
        self.last_pkt = pkt

        self.state = TftpStateSentRRQ(self)

        while self.state:
            try:
                log.debug("State is %s", self.state)
                self.cycle()
            except TftpTimeout as err:
                log.error(str(err))
                self.retry_count += 1
                if self.retry_count >= TIMEOUT_RETRIES:
                    log.debug("hit max retries, giving up")
                    raise
                else:
                    log.warning("resending last packet")
                    self.state.resendLast()

    def end(self):
        """Finish up the context."""
        TftpContext.end(self)
        self.metrics.end_time = time.time()
        log.debug("Set metrics.end_time to %s", self.metrics.end_time)
        self.metrics.compute()