Factor out length retrieval in ASN.1 encoder
authorGreg Hudson <ghudson@mit.edu>
Fri, 6 Jan 2012 21:10:42 +0000 (21:10 +0000)
committerGreg Hudson <ghudson@mit.edu>
Fri, 6 Jan 2012 21:10:42 +0000 (21:10 +0000)
git-svn-id: svn://anonsvn.mit.edu/krb5/trunk@25615 dc483132-0cff-0310-8789-dd5450dbe970

src/lib/krb5/asn.1/asn1_encode.c

index 41edc6e6da80393e218bc9a09d171d50e58973c6..5fc1efdfd35c31146d92ba508d6e637045a90998 100644 (file)
@@ -191,10 +191,10 @@ asn1_encode_bitstring(asn1buf *buf, unsigned int len, const void *val,
     (assert((PTRINFO)->loadptr != NULL), (PTRINFO)->loadptr(PTR))
 #endif
 
-static int
+static unsigned int
 get_nullterm_sequence_len(const void *valp, const struct atype_info *seq)
 {
-    int i;
+    unsigned int i;
     const struct atype_info *a;
     const struct ptr_info *ptr;
     const void *elt, *eltptr;
@@ -215,7 +215,7 @@ get_nullterm_sequence_len(const void *valp, const struct atype_info *seq)
     return i;
 }
 static asn1_error_code
-encode_sequence_of(asn1buf *buf, int seqlen, const void *val,
+encode_sequence_of(asn1buf *buf, unsigned int seqlen, const void *val,
                    const struct atype_info *eltinfo,
                    unsigned int *retlen);
 
@@ -225,7 +225,7 @@ encode_nullterm_sequence_of(asn1buf *buf, const void *val,
                             int can_be_empty,
                             unsigned int *retlen)
 {
-    int length = get_nullterm_sequence_len(val, type);
+    unsigned int length = get_nullterm_sequence_len(val, type);
     if (!can_be_empty && length == 0) return ASN1_MISSING_FIELD;
     return encode_sequence_of(buf, length, val, type, retlen);
 }
@@ -371,6 +371,39 @@ encode_type(asn1buf *buf, const void *val, const struct atype_info *a,
  * specified.  If omit_tag is non-NULL, omit the outer tag and return its
  * construction bit instead (only valid if the field has no context tag).
  */
+static asn1_error_code
+get_field_len(const void *val, const struct field_info *field,
+              unsigned int *retlen)
+{
+    const void *lenptr = (const char *)val + field->lenoff;
+
+    assert(field->lentype != NULL);
+    assert(field->lentype->type == atype_int ||
+           field->lentype->type == atype_uint);
+    assert(sizeof(int) <= sizeof(asn1_intmax));
+    assert(sizeof(unsigned int) <= sizeof(asn1_uintmax));
+    if (field->lentype->type == atype_int) {
+        const struct int_info *tinfo = field->lentype->tinfo;
+        asn1_intmax xlen = tinfo->loadint(lenptr);
+        if (xlen < 0)
+            return EINVAL;
+        if ((unsigned int)xlen != (asn1_uintmax)xlen)
+            return EINVAL;
+        if ((unsigned int)xlen > UINT_MAX)
+            return EINVAL;
+        *retlen = (unsigned int)xlen;
+    } else {
+        const struct uint_info *tinfo = field->lentype->tinfo;
+        asn1_uintmax xlen = tinfo->loaduint(lenptr);
+        if ((unsigned int)xlen != xlen)
+            return EINVAL;
+        if (xlen > UINT_MAX)
+            return EINVAL;
+        *retlen = (unsigned int)xlen;
+    }
+    return 0;
+}
+
 static asn1_error_code
 encode_a_field(asn1buf *buf, const void *val, const struct field_info *field,
                unsigned int *retlen, asn1_construction *omit_tag)
@@ -405,9 +438,8 @@ encode_a_field(asn1buf *buf, const void *val, const struct field_info *field,
     }
     case field_sequenceof_len:
     {
-        const void *dataptr, *lenptr;
-        int slen;
-        const struct atype_info *a;
+        const void *dataptr = (const char *)val + field->dataoff;
+        unsigned int slen;
         const struct ptr_info *ptrinfo;
 
         /*
@@ -415,38 +447,16 @@ encode_a_field(asn1buf *buf, const void *val, const struct field_info *field,
          * address we compute is a pointer-to-pointer, and that's what
          * field->atype must help us dereference.
          */
-        dataptr = (const char *)val + field->dataoff;
-        lenptr = (const char *)val + field->lenoff;
         assert(field->atype->type == atype_ptr);
         ptrinfo = field->atype->tinfo;
         dataptr = LOADPTR(dataptr, ptrinfo);
-        a = ptrinfo->basetype;
-        assert(field->lentype != 0);
-        assert(field->lentype->type == atype_int || field->lentype->type == atype_uint);
-        assert(sizeof(int) <= sizeof(asn1_intmax));
-        assert(sizeof(unsigned int) <= sizeof(asn1_uintmax));
-        if (field->lentype->type == atype_int) {
-            const struct int_info *tinfo = field->lentype->tinfo;
-            asn1_intmax xlen = tinfo->loadint(lenptr);
-            if (xlen < 0)
-                return EINVAL;
-            if ((unsigned int) xlen != (asn1_uintmax) xlen)
-                return EINVAL;
-            if ((unsigned int) xlen > INT_MAX)
-                return EINVAL;
-            slen = (int) xlen;
-        } else {
-            const struct uint_info *tinfo = field->lentype->tinfo;
-            asn1_uintmax xlen = tinfo->loaduint(lenptr);
-            if ((unsigned int) xlen != xlen)
-                return EINVAL;
-            if (xlen > INT_MAX)
-                return EINVAL;
-            slen = (int) xlen;
-        }
+        retval = get_field_len(val, field, &slen);
+        if (retval)
+            return retval;
         if (slen != 0 && dataptr == NULL)
             return ASN1_MISSING_FIELD;
-        retval = encode_sequence_of(buf, slen, dataptr, a, &length);
+        retval = encode_sequence_of(buf, slen, dataptr, ptrinfo->basetype,
+                                    &length);
         if (retval)
             return retval;
         construction = CONSTRUCTED;
@@ -469,49 +479,23 @@ encode_a_field(asn1buf *buf, const void *val, const struct field_info *field,
     }
     case field_string:
     {
-        const void *dataptr, *lenptr;
+        const void *dataptr = (const char *)val + field->dataoff;
         const struct atype_info *a;
-        size_t slen;
+        unsigned int slen;
         const struct string_info *string;
 
-        dataptr = (const char *)val + field->dataoff;
-        lenptr = (const char *)val + field->lenoff;
-
         a = field->atype;
         assert(a->type == atype_string || a->type == atype_opaque);
         assert(!(a->type == atype_opaque && field->tag_implicit));
-        assert(field->lentype != 0);
-        assert(field->lentype->type == atype_int || field->lentype->type == atype_uint);
-        assert(sizeof(int) <= sizeof(asn1_intmax));
-        assert(sizeof(unsigned int) <= sizeof(asn1_uintmax));
-        if (field->lentype->type == atype_int) {
-            const struct int_info *tinfo = field->lentype->tinfo;
-            asn1_intmax xlen = tinfo->loadint(lenptr);
-            if (xlen < 0)
-                return EINVAL;
-            if ((size_t) xlen != (asn1_uintmax) xlen)
-                return EINVAL;
-            slen = (size_t) xlen;
-        } else {
-            const struct uint_info *tinfo = field->lentype->tinfo;
-            asn1_uintmax xlen = tinfo->loaduint(lenptr);
-            if ((size_t) xlen != xlen)
-                return EINVAL;
-            slen = (size_t) xlen;
-        }
-
+        retval = get_field_len(val, field, &slen);
+        if (retval)
+            return retval;
         string = a->tinfo;
         dataptr = LOADPTR(dataptr, string);
-        if (slen == SIZE_MAX)
-            /* Error - negative or out of size_t range.  */
-            return EINVAL;
         if (dataptr == NULL && slen != 0)
             return ASN1_MISSING_FIELD;
-        /* Currently string encoders want "unsigned int" for length. */
-        if (slen != (unsigned int)slen)
-            return EINVAL;
         assert(string->enclen != NULL);
-        retval = string->enclen(buf, (unsigned int) slen, dataptr, &length);
+        retval = string->enclen(buf, slen, dataptr, &length);
         if (retval)
             return retval;
         if (a->type == atype_string)
@@ -599,21 +583,20 @@ just_encode_sequence(asn1buf *buf, const void *val,
 }
 
 static asn1_error_code
-encode_sequence_of(asn1buf *buf, int seqlen, const void *val,
+encode_sequence_of(asn1buf *buf, unsigned int seqlen, const void *val,
                    const struct atype_info *eltinfo,
                    unsigned int *retlen)
 {
     asn1_error_code retval;
-    unsigned int sum = 0;
-    int i;
+    unsigned int sum = 0, i;
 
-    for (i = seqlen-1; i >= 0; i--) {
+    for (i = seqlen; i > 0; i--) {
         const void *eltptr;
         unsigned int length;
         const struct atype_info *a = eltinfo;
 
         assert(eltinfo->size != 0);
-        eltptr = (const char *)val + i * eltinfo->size;
+        eltptr = (const char *)val + (i - 1) * eltinfo->size;
         retval = encode_type(buf, eltptr, a, &length, NULL);
         if (retval)
             return retval;