Update preset.
[bertos.git] / bertos / net / tftp.c
1 /**
2  * \file
3  * <!--
4  * This file is part of BeRTOS.
5  *
6  * Bertos is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 2 of the License, or
9  * (at your option) any later version.
10  *
11  * This program is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this program; if not, write to the Free Software
18  * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
19  *
20  * As a special exception, you may use this file as part of a free software
21  * library without restriction.  Specifically, if other files instantiate
22  * templates or use macros or inline functions from this file, or you compile
23  * this file and link it with other files to produce an executable, this
24  * file does not by itself cause the resulting executable to be covered by
25  * the GNU General Public License.  This exception does not however
26  * invalidate any other reasons why the executable file might be covered by
27  * the GNU General Public License.
28  *
29  * Copyright 2010 Develer S.r.l. (http://www.develer.com/)
30  *
31  * -->
32  *
33  * \brief TFTP protocol implementation
34  *
35  * \author Luca Ottaviano <lottaviano@develer.com>
36  *
37  * notest:all
38  */
39
40 #include "tftp.h"
41 #include "cfg/cfg_tftp.h"
42 #define LOG_LEVEL   TFTP_LOG_LEVEL
43 #define LOG_FORMAT  TFTP_LOG_FORMAT
44 #include <cfg/log.h>
45
46 //#include <lwip/in.h>
47 #include <lwip/inet.h>
48 #include <lwip/sockets.h>
49 #include <string.h> //memset
50
51 #define TFTP_PACKET_SIZE 516
52
53 #define DECLARE_TIMEOUT(name, timeout) \
54         struct timeval name; \
55         name.tv_sec = timeout / 1000; \
56         name.tv_usec = (timeout % 1000) * 1000;
57
58 #define KFT_TFTPSESSION MAKE_ID('T', 'F', 'T', 'P')
59 INLINE TftpSession *TFTP_CAST(KFile *fd)
60 {
61         ASSERT(fd->_type == KFT_TFTPSESSION);
62         return (TftpSession *)containerof(fd, TftpSession, kfile_request);
63 }
64
65 /*
66  * Check if received data is correct and send ACK if ok.
67  */
68 static int checkPacket(TftpSession *ctx, const Tftpframe *frame)
69 {
70         LOG_INFO("Checking block %hd\n", ctx->block);
71         if (ntohs(frame->hdr.opcode) != TFTP_DATA)
72         {
73                 LOG_INFO("Opcode != TFTP_DATA (%hd != %d)\n", ntohs(frame->hdr.opcode), TFTP_DATA);
74                 return -1;
75         }
76         if (ntohs(frame->hdr.th_u.block) != ctx->block + 1)
77                 return -1;
78
79         ctx->block++;
80         // if everything was ok, send ACK
81         // ACK is already in network order
82         struct ackframe ack;
83         ack.opcode = TFTP_ACK;
84         ack.block_num = htons(ctx->block);
85         ssize_t rc = lwip_sendto(ctx->sock, &ack, 4, 0, (struct sockaddr *)&ctx->addr, ctx->addr_len);
86         if (rc == 4)
87                 return rc;
88         else
89                 return -1;
90 }
91
92 /*
93  * Return >0 if there's something to read in ctx, 0 on timeout, -1 on errors
94  */
95 static int tftp_waitEvent(TftpSession *ctx, struct timeval *timeout)
96 {
97         fd_set inset;
98         FD_ZERO(&inset);
99         FD_SET(ctx->sock, &inset);
100         struct timeval tmp = *timeout;
101         return lwip_select(ctx->sock + 1, &inset, NULL, NULL, &tmp);
102 }
103
104 /*
105  * Read a block from TFTP.
106  * \param size Must be exactly 516 bytes
107  * \param timeout Time to wait the network connection, may be NULL to wait forever
108  * \return Number of bytes read if success, TFTP_ERR_TIMEOUT on timeout, TFTP_ERR otherwise
109  */
110 static ssize_t tftp_readPacket(TftpSession *ctx, Tftpframe *frame, mtime_t timeout)
111 {
112         DECLARE_TIMEOUT(wait_tm, timeout);
113
114         int res = tftp_waitEvent(ctx, &wait_tm);
115         if (res == 0)
116                 return TFTP_ERR_TIMEOUT;
117         if (res == -1)
118                 return TFTP_ERR;
119
120         ssize_t rlen = lwip_recvfrom(ctx->sock, frame, sizeof(Tftpframe), 0, NULL, NULL);
121         LOG_INFO("Received %zd bytes\n", rlen);
122         if (rlen > 0 && (checkPacket(ctx, frame) > 0))
123                 return rlen;
124         else
125                 return TFTP_ERR;
126 }
127
128 static size_t tftp_read(struct KFile *fd, void *buf, size_t size)
129 {
130         TftpSession *fds = TFTP_CAST(fd);
131         uint8_t *_buf = (uint8_t *) buf;
132         size_t read_bytes = 0;
133         size_t offset = fds->valid_data - fds->bytes_available;
134
135         if (fds->pending_ack)
136         {
137                 ASSERT(fds->block == 0);
138                 struct ackframe ack;
139                 ack.opcode = TFTP_ACK;
140                 ack.block_num = fds->block;
141                 lwip_sendto(fds->sock, &ack, 4, 0, (struct sockaddr *)&fds->addr, fds->addr_len);
142                 fds->pending_ack = false;
143         }
144
145         if (fds->bytes_available < size)
146         {
147                 /* check if we were called again after an error */
148                 if (fds->bytes_available > 0)
149                 {
150                         memcpy(_buf, fds->frame.data + offset, fds->bytes_available);
151                         LOG_INFO("ba < size. Copied %zd bytes from offset %zd\n", fds->bytes_available, offset);
152                         /* adjust buf and size */
153                         _buf += fds->bytes_available;
154                         size -= fds->bytes_available;
155                         read_bytes += fds->bytes_available;
156                 }
157
158                 if (!fds->is_xfer_end)
159                 {
160                         LOG_INFO("Waiting for new TFTP packet\n");
161                         /* get more data, we can wait since the function is blocking */
162                         ssize_t rd = tftp_readPacket(fds, &fds->frame, fds->timeout);
163                         if (rd < 0)
164                         {
165                                 fds->bytes_available = 0;
166                                 fds->error = rd;
167                                 return 0;
168                         }
169                         else
170                         {
171                                 if (rd < TFTP_PACKET_SIZE)
172                                 {
173                                         fds->is_xfer_end = true;
174                                         LOG_INFO("Received the last packet\n");
175                                 }
176                                 fds->bytes_available = (size_t)rd - sizeof(struct TftpHeader);
177                                 fds->valid_data = fds->bytes_available;
178                                 offset = 0;
179                         }
180                 }
181                 else
182                 {
183                         LOG_INFO("Transfer finished\n");
184                         fds->bytes_available -= fds->bytes_available;
185                         fds->valid_data = 0;
186                         return read_bytes;
187                 }
188         }
189
190         /* check how many bytes we need to copy */
191         size_t res = MIN(fds->bytes_available, size);
192         LOG_INFO("Copying %zd bytes from offset %zd\n", res, offset);
193         memcpy(_buf, fds->frame.data + offset, res);
194         fds->bytes_available -= res;
195         read_bytes += res;
196         return read_bytes;
197 }
198
199 static int tftp_error(struct KFile *fd)
200 {
201         TftpSession *fds = TFTP_CAST(fd);
202         return fds->error;
203 }
204
205 static void tftp_clearerr(struct KFile *fd)
206 {
207         TftpSession *fds = TFTP_CAST(fd);
208         fds->error = 0;
209 }
210
211 static int tftp_close(struct KFile *fd)
212 {
213         TftpSession *fds = TFTP_CAST(fd);
214         struct errframe err;
215         if (fds->pending_ack)
216         {
217                 err.opcode = TFTP_PROTOERR;
218                 err.errcode = TFTP_PROTOERR_ACCESS_VIOLATION;
219                 err.str = '\0';
220                 lwip_sendto(fds->sock, &err, 5, 0, (struct sockaddr *)&fds->addr, fds->addr_len);
221                 LOG_INFO("Closed connection upon user request\n");
222         }
223         return 0;
224 }
225
226 static void resetTftpState(TftpSession *ctx)
227 {
228         ctx->block = 0;
229         ctx->error = 0;
230         ctx->bytes_available = 0;
231         ctx->valid_data = 0;
232         ctx->is_xfer_end = false;
233         ctx->pending_ack = false;
234 }
235
236 /**
237  * Listen for incoming tftp sessions.
238  *
239  * \note Only write requests are accepted.
240  *
241  * \param ctx Initialized TftpChannel
242  * \param filename String to be filled with file name to be written
243  * \param len Length of the filename
244  * \param mode Open mode for the returned KFile
245  * \return KFile pointer to read from
246  */
247 KFile *tftp_listen(TftpSession *ctx, char *filename, size_t len, TftpOpenMode *mode)
248 {
249         DECLARE_TIMEOUT(wait_tm, ctx->timeout);
250         resetTftpState(ctx);
251
252         int res = tftp_waitEvent(ctx, &wait_tm);
253         if (res == 0)
254         {
255                 ctx->error = TFTP_ERR_TIMEOUT;
256                 return NULL;
257         }
258         if (res == -1)
259         {
260                 ctx->error = TFTP_ERR;
261                 return NULL;
262         }
263
264         // listen onto TFTP port
265         ctx->addr_len = sizeof(ctx->addr);
266         ssize_t rd = 0;
267         if ((rd = lwip_recvfrom(ctx->sock, &ctx->frame, sizeof(Tftpframe), 0, (struct sockaddr *)&ctx->addr, &ctx->addr_len)) > 0)
268         {
269                 // check if the packet is WRQ, otherwise discard the packet
270                 if (ctx->frame.hdr.opcode == TFTP_WRQ)
271                 {
272                         *mode = TFTP_WRITE;
273                         ctx->pending_ack = true;
274                         strncpy(filename, (char *)&ctx->frame.hdr.th_u, len);
275                         filename[len - 1] = '\0';
276                         ctx->error = 0;
277                         return &ctx->kfile_request;
278                 }
279                 else
280                         *mode = TFTP_READ;
281         }
282         ctx->error = TFTP_ERR;
283         return NULL;
284 }
285
286 /**
287  * Init a server session
288  *
289  * Create a IPv4 session on all addresses and port \a port.
290  *
291  * \param ctx Context to be initialized as server
292  * \param port Port to listen incoming connections
293  * \param timeout Timeout to be used for tftp connections
294  * \return 0 if successful, -1 otherwise
295  */
296 int tftp_init(TftpSession *ctx, unsigned short port, mtime_t timeout)
297 {
298         DB(ctx->kfile_request._type = KFT_TFTPSESSION);
299         ctx->kfile_request.read = tftp_read;
300         ctx->kfile_request.error = tftp_error;
301         ctx->kfile_request.clearerr = tftp_clearerr;
302         ctx->kfile_request.close = tftp_close;
303         resetTftpState(ctx);
304
305         /* Unused kfile methods */
306         ctx->kfile_request.seek = NULL;
307         ctx->kfile_request.write = NULL;
308         ctx->kfile_request.flush = NULL;
309         ctx->kfile_request.reopen = NULL;
310
311         struct sockaddr_in sa;
312         sa.sin_family = AF_INET;
313         sa.sin_addr.s_addr = htonl(INADDR_ANY);
314         sa.sin_port = htons(port);
315         ctx->timeout = timeout;
316
317         ctx->sock = lwip_socket(AF_INET, SOCK_DGRAM, 0);
318         if (ctx->sock == -1)
319         {
320                 LOG_INFO("TFTP socket error\n");
321                 return -1;
322         }
323
324         if(lwip_bind(ctx->sock, (struct sockaddr *)&sa, sizeof(sa)))
325         {
326                 LOG_INFO("Error binding socket\n");
327                 return -1;
328         }
329         return 0;
330 }
331