sec: properly handle unaligned memory in xor_block() and xor_block_const()
authorarighi <arighi@38d2e660-2303-0410-9eaa-f027e97ec537>
Thu, 17 Feb 2011 14:15:51 +0000 (14:15 +0000)
committerarighi <arighi@38d2e660-2303-0410-9eaa-f027e97ec537>
Thu, 17 Feb 2011 14:15:51 +0000 (14:15 +0000)
This also fixes the following build warnings:

  warning: cast increases required alignment of target type

git-svn-id: https://src.develer.com/svnoss/bertos/trunk@4712 38d2e660-2303-0410-9eaa-f027e97ec537

bertos/sec/util.h

index 189e4ed715858734fcac66a270cbd69bfd3713fb..0e62fad7b5fdb58176b3e8c8bc095424b040e5a8 100644 (file)
 void password2key(const char *pwd, size_t pwd_len,
                                  uint8_t *key, size_t key_len);
 
-/**
- * Perform a bitwise xor between \a in and \a inout, and store
- * the result into \a inout.
- */
-INLINE void xor_block(uint8_t *out, const uint8_t *in1, const uint8_t* in2, size_t len);
-
-/**
- * Perform a bitwise xor over \a inout with constant \a k.
- */
-INLINE void xor_block_const(uint8_t *out, const uint8_t *in, uint8_t k, size_t len);
-
+/* Check if a pointer is aligned to a certain power-of-2 size */
+INLINE bool __is_aligned(const void *addr, size_t size)
+{
+       return ((size_t)addr & (size - 1)) == 0;
+}
 
-// FIXME: provide non-32bit fallback
-// FIXME: proper ifdef conditional
-#if 1 // 32-bit optimized versions
+INLINE void xor_block_8(uint8_t *out,
+               const uint8_t *in1, const uint8_t *in2, size_t len)
+{
+       while (len--)
+                *out++ = *in1++ ^ *in2++;
+}
 
-// FIXME: this code is currently buggy because it ignores alignment issues.
-INLINE void xor_block(uint8_t *out, const uint8_t *in1, const uint8_t* in2, size_t len)
+INLINE void xor_block_const_8(uint8_t *out,
+                       const uint8_t *in, uint8_t k, size_t len)
 {
-       ASSERT(((size_t)in1 % 4) == 0);
-       ASSERT(((size_t)in2 % 4) == 0);
-       ASSERT(((size_t)out % 4) == 0);
+       while (len--)
+                *out++ = *in++ ^ k;
+}
 
-       const uint32_t *ibuf1 = (const uint32_t *)in1;
-       const uint32_t *ibuf2 = (const uint32_t *)in2;
-       uint32_t *obuf = (uint32_t *)out;
-       size_t rem = (len & 3);
+INLINE void xor_block_32(uint32_t *out, const uint32_t *in1,
+                               const uint32_t *in2, size_t len)
+{
+       size_t rem = (len & (sizeof(uint32_t) - 1));
 
-       len /= 4;
+       len /= sizeof(uint32_t);
        while (len--)
-               *obuf++ = *ibuf1++ ^ *ibuf2++;
-
-       in1 = (const uint8_t*)ibuf1;
-       in2 = (const uint8_t*)ibuf2;
-       out = (uint8_t*)obuf;
-       while (rem--)
                *out++ = *in1++ ^ *in2++;
+       xor_block_8((uint8_t *)out,
+               (const uint8_t *)in1, (const uint8_t *)in2, rem);
 }
 
-INLINE void xor_block_const(uint8_t *out, const uint8_t *in, uint8_t k, size_t len)
+INLINE void xor_block_const_32(uint32_t *out, const uint32_t *in,
+                                       uint8_t k, size_t len)
 {
-       ASSERT(((size_t)in % 4) == 0);
-       ASSERT(((size_t)out % 4) == 0);
-
-       uint32_t k32 = k | ((uint32_t)k<<8) | ((uint32_t)k<<16) | ((uint32_t)k<<24);
-       const uint32_t *ibuf = (const uint32_t *)in;
-       uint32_t *obuf = (uint32_t *)out;
-       size_t rem = (len & 3);
+       uint32_t k32 = k | ((uint32_t)k << 8) |
+                       ((uint32_t)k << 16) | ((uint32_t)k << 24);
+       size_t rem = (len & (sizeof(uint32_t) - 1));
 
-       len /= 4;
+       len /= sizeof(uint32_t);
        while (len--)
-               *obuf++ = *ibuf++ ^ k32;
+               *out++ = *in++ ^ k32;
+       xor_block_const_8((uint8_t *)out, (const uint8_t *)in, k, rem);
+}
 
-       in = (const uint8_t*)ibuf;
-       out = (uint8_t*)obuf;
-       while (rem--)
-               *out++ = *in++ ^ k;
+/**
+ * Perform a bitwise xor between \a in and \a inout, and store
+ * the result into \a inout.
+ */
+INLINE void xor_block(void *out, const void *in1, const void *in2, size_t len)
+{
+       if (__is_aligned(out, sizeof(uint32_t)) &&
+                       __is_aligned(in1, sizeof(uint32_t)) &&
+                       __is_aligned(in2, sizeof(uint32_t)))
+       {
+               uint32_t *obuf = (uint32_t *)((size_t)out);
+               const uint32_t *ibuf1 = (const uint32_t *)((size_t)in1);
+               const uint32_t *ibuf2 = (const uint32_t *)((size_t)in2);
+
+               xor_block_32(obuf, ibuf1, ibuf2, len);
+       }
+       else
+       {
+               uint8_t *obuf = (uint8_t *)((size_t)out);
+               const uint8_t *ibuf1 = (const uint8_t *)((size_t)in1);
+               const uint8_t *ibuf2 = (const uint8_t *)((size_t)in2);
+
+               xor_block_8(obuf, ibuf1, ibuf2, len);
+       }
 }
 
-#endif
+/**
+ * Perform a bitwise xor over \a inout with constant \a k.
+ */
+INLINE void xor_block_const(uint8_t *out, const uint8_t *in, uint8_t k, size_t len)
+{
+       if (__is_aligned(out, sizeof(uint32_t)) &&
+                       __is_aligned(in, sizeof(uint32_t)))
+       {
+               uint32_t *obuf = (uint32_t *)((size_t)out);
+               const uint32_t *ibuf = (const uint32_t *)((size_t)in);
+
+               xor_block_const_32(obuf, ibuf, k, len);
+       }
+       else
+       {
+               uint8_t *obuf = (uint8_t *)((size_t)out);
+               const uint8_t *ibuf = (const uint8_t *)((size_t)in);
+
+               xor_block_const_8(obuf, ibuf, k, len);
+       }
+}
 
 #endif /* SEC_UTIL_H */