removed test_gnu_dummy_s2k_extension(); no longer necessary
[monkeysphere.git] / src / keytrans / gnutls-helpers.c
1 /* Author: Daniel Kahn Gillmor <dkg@fifthhorseman.net> */
2 /* Date: Fri, 04 Apr 2008 19:31:16 -0400 */
3 /* License: GPL v3 or later */
4
5 #include "gnutls-helpers.h"
6 /* for htonl() */
7 #include <arpa/inet.h>
8
9 /* for setlocale() */
10 #include <locale.h>
11
12 /* for isalnum() */
13 #include <ctype.h>
14
15 /* for exit() */
16 #include <unistd.h>
17
18 #include <assert.h>
19
20 /* higher levels allow more frivolous error messages through. 
21    this is set with the MONKEYSPHERE_DEBUG variable */
22 static int loglevel = 0;
23
24 void err(int level, const char* fmt, ...) {
25   va_list ap;
26   if (level > loglevel)
27     return;
28   va_start(ap, fmt);
29   vfprintf(stderr, fmt, ap);
30   va_end(ap);
31   fflush(stderr);
32 }
33
34 void logfunc(int level, const char* string) {
35   fprintf(stderr, "GnuTLS Logging (%d): %s\n", level, string);
36 }
37
38 void init_keyid(gnutls_openpgp_keyid_t keyid) {
39   memset(keyid, 'x', sizeof(gnutls_openpgp_keyid_t));
40 }
41
42
43
44 void make_keyid_printable(printable_keyid out, gnutls_openpgp_keyid_t keyid)
45 {
46   assert(sizeof(out) >= 2*sizeof(keyid));
47   hex_print_data((char*)out, (const unsigned char*)keyid, sizeof(keyid));
48 }
49
50 /* you must have twice as many bytes in the out buffer as in the in buffer */
51 void hex_print_data(char* out, const unsigned char* in, size_t incount)
52 {
53   static const char hex[16] = "0123456789ABCDEF";
54   unsigned int inix = 0, outix = 0;
55   
56   while (inix < incount) {
57     out[outix] = hex[(in[inix] >> 4) & 0x0f];
58     out[outix + 1] = hex[in[inix] & 0x0f];
59     inix++;
60     outix += 2;
61   }
62 }
63
64 unsigned char hex2bin(unsigned char x) {
65   if ((x >= '0') && (x <= '9')) 
66     return x - '0';
67   if ((x >= 'A') && (x <= 'F')) 
68     return 10 + x - 'A';
69   if ((x >= 'a') && (x <= 'f')) 
70     return 10 + x - 'a';
71   return 0xff;
72 }
73
74 void collapse_printable_keyid(gnutls_openpgp_keyid_t out, printable_keyid in) {
75   unsigned int pkix = 0, outkix = 0;
76   while (pkix < sizeof(printable_keyid)) {
77     unsigned hi = hex2bin(in[pkix]);
78     unsigned lo = hex2bin(in[pkix + 1]);
79     if (hi == 0xff) {
80       err(0, "character '%c' is not a hex char\n", in[pkix]);
81       exit(1);
82     }
83     if (lo == 0xff) {
84       err(0, "character '%c' is not a hex char\n", in[pkix + 1]);
85       exit(1);
86     }
87     out[outkix] = lo | (hi << 4);
88
89     pkix += 2;
90     outkix++;
91   }
92 }
93
94 unsigned int hexstring2bin(unsigned char* out, const char* in) {
95   unsigned int pkix = 0, outkix = 0;
96   int hi = 0; /* which nybble is it? */
97   
98   while (in[pkix]) {
99     unsigned char z = hex2bin(in[pkix]);
100     if (z != 0xff) {
101       if (!hi) {
102         if (out) out[outkix] = (z << 4);
103         hi = 1;
104       } else {
105         if (out) out[outkix] |= z;
106         hi = 0;
107         outkix++;
108       }
109       pkix++;
110     }      
111   }
112   return outkix*8 + (hi ? 4 : 0);
113 }
114
115 int convert_string_to_keyid(gnutls_openpgp_keyid_t out, const char* str) {
116   printable_keyid p;
117   int ret;
118
119   ret = convert_string_to_printable_keyid(p, str);
120   if (ret == 0) 
121     collapse_printable_keyid(out, p);
122   return ret;
123 }
124 int convert_string_to_printable_keyid(printable_keyid pkeyid, const char* str) {
125   int arglen, x;
126   arglen = 0;
127   x = 0;
128   while ((arglen <= sizeof(printable_keyid)) &&
129          (str[x] != '\0')) {
130     if (isxdigit(str[x])) {
131       if (arglen == sizeof(printable_keyid)) {
132         err(0, "There are more than %d hex digits in the keyid '%s'\n", sizeof(printable_keyid), str);
133         return 1;
134       }
135       pkeyid[arglen] = str[x];
136       arglen++;
137     }
138     x++;
139   }
140   
141   if (arglen != sizeof(printable_keyid)) {
142     err(0, "Keyid '%s' is not %d hex digits in length\n", str, sizeof(printable_keyid));
143     return 1;
144   }
145   return 0;
146 }
147
148
149
150 int init_gnutls() {
151   const char* version = NULL;
152   const char* debug_string = NULL;
153   int ret;
154
155   if (debug_string = getenv("MONKEYSPHERE_DEBUG"), debug_string) {
156     loglevel = atoi(debug_string);
157   }
158
159   if (ret = gnutls_global_init(), ret) {
160     err(0, "Failed to do gnutls_global_init() (error: %d)\n", ret);
161     return 1;
162   }
163
164   version = gnutls_check_version(NULL);
165
166   if (version) 
167     err(1, "gnutls version: %s\n", version);
168   else {
169     err(0, "no gnutls version found!\n");
170     return 1;
171   }
172
173   gnutls_global_set_log_function(logfunc);
174   
175   gnutls_global_set_log_level(loglevel);
176   err(1, "set log level to %d\n", loglevel);
177
178   return 0;
179 }
180
181 void init_datum(gnutls_datum_t* d) {
182   d->data = NULL;
183   d->size = 0;
184 }
185 void copy_datum(gnutls_datum_t* dest, const gnutls_datum_t* src) {
186   dest->data = gnutls_realloc(dest->data, src->size);
187   dest->size = src->size;
188   memcpy(dest->data, src->data, src->size);
189 }
190 int compare_data(const gnutls_datum_t* a, const gnutls_datum_t* b) {
191   if (a->size > b->size) {
192     err(0,"a is larger\n");
193     return 1;
194   }
195   if (a->size < b->size) {
196     err(0,"b is larger\n");
197     return -1;
198   }
199   return memcmp(a->data, b->data, a->size);
200 }
201 void free_datum(gnutls_datum_t* d) {
202   gnutls_free(d->data);
203   d->data = NULL;
204   d->size = 0;
205 }
206
207 /* read the passed-in string, store in a single datum */
208 int set_datum_string(gnutls_datum_t* d, const char* s) {
209   unsigned int x = strlen(s)+1;
210   unsigned char* c = NULL;
211
212   c = gnutls_realloc(d->data, x);
213   if (NULL == c)
214     return -1;
215   d->data = c;
216   d->size = x;
217   memcpy(d->data, s, x);
218   return 0;
219 }
220
221 /* read the passed-in file descriptor until EOF, store in a single
222    datum */
223 int set_datum_fd(gnutls_datum_t* d, int fd) {
224   unsigned int bufsize = 1024;
225   unsigned int len = 0;
226
227   FILE* f = fdopen(fd, "r");
228   if (bufsize > d->size) {
229     bufsize = 1024;
230     d->data = gnutls_realloc(d->data, bufsize);
231     if (d->data == NULL) {
232       err(0,"out of memory!\n");
233       return -1;
234     }
235     d->size = bufsize;
236   } else {
237     bufsize = d->size;
238   }
239   f = fdopen(fd, "r");
240   if (NULL == f) {
241     err(0,"could not fdopen FD %d\n", fd);
242   }
243   clearerr(f);
244   while (!feof(f) && !ferror(f)) { 
245     if (len == bufsize) {
246       /* allocate more space by doubling: */
247       bufsize *= 2;
248       d->data = gnutls_realloc(d->data, bufsize);
249       if (d->data == NULL) {
250         err(0,"out of memory!\n"); 
251         return -1;
252       };
253       d->size = bufsize;
254     }
255     len += fread(d->data + len, 1, bufsize - len, f);
256     /*     err(0,"read %d bytes\n", len); */
257   }
258   if (ferror(f)) {
259     err(0,"Error reading from fd %d (error: %d) (error: %d '%s')\n", fd, ferror(f), errno, strerror(errno));
260     return -1;
261   }
262     
263   /* touch up buffer size to match reality: */
264   d->data = gnutls_realloc(d->data, len);
265   d->size = len;
266   return 0;
267 }
268
269 /* read the file indicated (by name) in the fname parameter.  store
270    its entire contents in a single datum. */
271 int set_datum_file(gnutls_datum_t* d, const char* fname) {
272   struct stat sbuf;
273   unsigned char* c = NULL;
274   FILE* file = NULL;
275   size_t x = 0;
276
277   if (0 != stat(fname, &sbuf)) {
278     err(0,"failed to stat '%s'\n", fname);
279     return -1;
280   }
281   
282   c = gnutls_realloc(d->data, sbuf.st_size);
283   if (NULL == c) {
284     err(0,"failed to allocate %d bytes for '%s'\n", sbuf.st_size, fname);
285     return -1;
286   }
287
288   d->data = c;
289   d->size = sbuf.st_size;
290   file = fopen(fname, "r");
291   if (NULL == file) {
292     err(0,"failed to open '%s' for reading\n",  fname);
293     return -1;
294   }
295
296   x = fread(d->data, d->size, 1, file);
297   if (x != 1) {
298     err(0,"tried to read %d bytes, read %d instead from '%s'\n", d->size, x, fname);
299     fclose(file);
300     return -1;
301   }
302   fclose(file);
303   return 0;
304 }
305
306 int write_datum_fd(int fd, const gnutls_datum_t* d) {
307   if (d->size != write(fd, d->data, d->size)) {
308     err(0,"failed to write body of datum.\n");
309     return -1;
310   }
311   return 0;
312 }
313
314
315 int write_datum_fd_with_length(int fd, const gnutls_datum_t* d) {
316   uint32_t len;
317   int looks_negative = (d->data[0] & 0x80);
318   unsigned char zero = 0;
319
320   /* if the first bit is 1, then the datum will appear negative in the
321      MPI encoding style used by OpenSSH.  In that case, we'll increase
322      the length by one, and dump out one more byte */
323
324   if (looks_negative) {
325     len = htonl(d->size + 1);
326   } else {
327     len = htonl(d->size);
328   }
329   if (write(fd, &len, sizeof(len)) != sizeof(len)) {
330     err(0,"failed to write size of datum.\n");
331     return -2;
332   }
333   if (looks_negative) {
334     if (write(fd, &zero, 1) != 1) {
335       err(0,"failed to write padding byte for MPI.\n");
336       return -2;
337     }
338   }
339   return write_datum_fd(fd, d);
340 }
341
342 int write_data_fd_with_length(int fd, const gnutls_datum_t** d, unsigned int num) {
343   unsigned int i;
344   int ret;
345
346   for (i = 0; i < num; i++)
347     if (ret = write_datum_fd_with_length(fd, d[i]), ret != 0)
348       return ret;
349
350   return 0;
351 }
352
353
354 int datum_from_string(gnutls_datum_t* d, const char* str) {
355   d->size = strlen(str);
356   d->data = gnutls_realloc(d->data, d->size);
357   if (d->data == 0)
358     return ENOMEM;
359   memcpy(d->data, str, d->size);
360   return 0;
361 }
362
363
364 int create_writing_pipe(pid_t* pid, const char* path, char* const argv[]) {
365   int p[2];
366   int ret;
367
368   if (pid == NULL) {
369     err(0,"bad pointer passed to create_writing_pipe()\n");
370     return -1;
371   }
372
373   if (ret = pipe(p), ret == -1) {
374     err(0,"failed to create a pipe (error: %d \"%s\")\n", errno, strerror(errno));
375     return -1;
376   }
377
378   *pid = fork();
379   if (*pid == -1) {
380     err(0,"Failed to fork (error: %d \"%s\")\n", errno, strerror(errno));
381     return -1;
382   }
383   if (*pid == 0) { /* this is the child */
384     close(p[1]); /* close unused write end */
385     
386     if (0 != dup2(p[0], 0)) { /* map the reading end into stdin */
387       err(0,"Failed to transfer reading file descriptor to stdin (error: %d \"%s\")\n", errno, strerror(errno));
388       exit(1);
389     }
390     execvp(path, argv);
391     err(0,"exec %s failed (error: %d \"%s\")\n", path, errno, strerror(errno));
392     /* close the open file descriptors */
393     close(p[0]);
394     close(0);
395
396     exit(1);
397   } else { /* this is the parent */
398     close(p[0]); /* close unused read end */
399     return p[1];
400   }
401 }
402
403 int validate_ssh_host_userid(const char* userid) {
404   char* oldlocale = setlocale(LC_ALL, "C");
405   
406   /* choke if userid does not match the expected format
407      ("ssh://fully.qualified.domain.name") */
408   if (strncmp("ssh://", userid, strlen("ssh://")) != 0) {
409     err(0,"The user ID should start with ssh:// for a host key\n");
410     goto fail;
411   }
412   /* so that isalnum will work properly */
413   userid += strlen("ssh://");
414   while (0 != (*userid)) {
415     if (!isalnum(*userid)) {
416       err(0,"label did not start with a letter or a digit! (%s)\n", userid);
417       goto fail;
418     }
419     userid++;
420     while (isalnum(*userid) || ('-' == (*userid)))
421       userid++;
422     if (('.' == (*userid)) || (0 == (*userid))) { /* clean end of label:
423                                                  check last char
424                                                  isalnum */
425       if (!isalnum(*(userid - 1))) {
426         err(0,"label did not end with a letter or a digit!\n");
427         goto fail;
428       }
429       if ('.' == (*userid)) /* advance to the start of the next label */
430         userid++;
431     } else {
432       err(0,"invalid character in domain name: %c\n", *userid);
433       goto fail;
434     }
435   }
436   /* ensure that the last character is valid: */
437   if (!isalnum(*(userid - 1))) {
438     err(0,"hostname did not end with a letter or a digit!\n");
439     goto fail;
440   }
441   /* FIXME: fqdn's can be unicode now, thanks to RFC 3490 -- how do we
442      make sure that we've got an OK string? */
443
444   return 0;
445
446  fail:
447   setlocale(LC_ALL, oldlocale);
448   return 1;
449 }
450
451 /* http://tools.ietf.org/html/rfc4880#section-5.5.2 */
452 size_t get_openpgp_mpi_size(gnutls_datum_t* d) {
453   return 2 + d->size;
454 }
455
456 int write_openpgp_mpi_to_fd(int fd, gnutls_datum_t* d) {
457   uint16_t x;
458
459   x = d->size * 8;
460   x = htons(x);
461   
462   write(fd, &x, sizeof(x));
463   write(fd, d->data, d->size);
464   
465   return 0;
466 }