//
//
// gf2util.c - misc routines for polynomials with binary coefficients
//
// warning: written by Scott Duplichan long ago, use at your own risk
//          sduplichan,yahoo.com
//
// 
#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <string.h>
#include <ctype.h>
#include <basetsd.h>

//----------------------------------------------------------------------------
// compiler and platform dependencies
//

// INTN/UINTN name the machine word used for bit-level access to INTEGER.
// UINTN expands to "unsigned int long", which C accepts as unsigned long.
// NOTE(review): UINTN_BITS is a hard-coded 32 and is not derived from
// sizeof (UINTN). Where long is wider than 32 bits the uintn view would use
// only the low 32 bits of each word and would no longer line up with the
// uint8/uint16/uint32 views of the same union - confirm the target platform.
#define INTN long            // longest integer compiler handles
#define UINTN unsigned int INTN // unsigned int version
#define UINTN_BITS 32           // number of bits in INTN, UINTN

// supply stand-ins for two MSVC-provided names when building elsewhere
#if !defined (_MSC_VER)
#define __forceinline
#define max(a,b) (((a) > (b)) ? (a) : (b))
#endif

//----------------------------------------------------------------------------

// fixed-size element types used for the narrow views of INTEGER
typedef unsigned char  UINT8;
typedef unsigned short UINT16; 

// MAXBITS is the capacity, in bits, of one INTEGER. The #if that follows
// cannot fire as written because MAXBITS is defined on the line before it.
#define MAXBITS 16384
#if !defined MAXBITS
#error must define MAXBITS
#endif

// number of elements of each width needed to hold MAXBITS bits (rounded up)
#define UINTN_COUNT ((MAXBITS + UINTN_BITS - 1) / UINTN_BITS)
#define UINT8_COUNT  ((MAXBITS +  8 - 1) / 8)
#define UINT16_COUNT ((MAXBITS + 16 - 1) / 16)
#define UINT32_COUNT ((MAXBITS + 32 - 1) / 32)
#define UINT64_COUNT ((MAXBITS + 64 - 1) / 64)

//----------------------------------------------------------------------------
// large integer structure
//
// One MAXBITS-bit value, reachable through several element widths that
// overlay the same storage. The bit helpers in this file index the uintn
// view; the add/subtract helpers walk the uint16 view; hex output and the
// overflow check read the uint8 view.
// NOTE(review): UINT32 is not defined in this file - presumably it comes
// from <basetsd.h>; confirm when building with another toolchain.
//
typedef union
   {
   UINTN  uintn    [UINTN_COUNT];   // machine-word view
   UINT8  uint8    [UINT8_COUNT];   // byte view
   UINT16 uint16   [UINT16_COUNT];  // 16-bit view
   UINT32 uint32   [UINT32_COUNT];  // 32-bit view
   }
INTEGER;

//----------------------------------------------------------------------------
//
// logError - print a printf-style message to stderr, preceded and followed
//            by a newline, then terminate the program with exit status 1
//
// message : printf format string
// ...     : arguments consumed by the format string
//
// does not return
//

static void logError (char *message,...)
   {
   va_list Marker;

   // format directly to the stream so that no fixed-size intermediate
   // buffer can be overrun by a long message
   fputc ('\n', stderr);
   va_start (Marker, message);
   vfprintf (stderr, message, Marker);
   va_end (Marker);
   fputc ('\n', stderr);
   exit (1);
   }

//----------------------------------------------------------------------------

// Small constant values. The brace initializer sets the first element of the
// first union member (uintn [0]); every other element is zero.
static INTEGER IntegerZero = {0};
static INTEGER IntegerOne  = {1};
static INTEGER IntegerTwo  = {2};
static INTEGER IntegerTen  = {10};
// Output format selector read by integerToAscii: 2, 8, 10 and 16 choose
// binary, octal, decimal and hex; any other value gives polynomial text.
static int displayBase = 2; // default to binary display

//----------------------------------------------------------------------------
//
// extractbit - fetch a single bit of an extended integer
//
// data      : integer to read
// bitNumber : zero-based bit position, bit 0 being the least significant
//
// returns 0 or 1; bitNumber is not range checked
//
static __forceinline unsigned int extractbit (INTEGER *data, unsigned int bitNumber)
   {
   unsigned int wordIndex = bitNumber / UINTN_BITS;
   unsigned int bitOffset = bitNumber % UINTN_BITS;
   UINTN        word      = data->uintn [wordIndex];

   return (unsigned) ((word >> bitOffset) & 1);
   }

//----------------------------------------------------------------------------
//
// extractbits - gather bits lsb..msb (inclusive) of an extended integer
//               into an ordinary unsigned value
//
// data : integer to read
// lsb  : position of the bit that becomes bit 0 of the result
// msb  : position of the last bit to include
//
// returns the selected field; zero when lsb is greater than msb
//
static __forceinline unsigned int extractbits (INTEGER *data, unsigned int lsb, unsigned int msb)
   {
   unsigned int field = 0;
   unsigned int bitNumber = lsb;

   while (bitNumber <= msb)
      {
      if (extractbit (data, bitNumber) != 0)
         field |= 1 << (bitNumber - lsb);
      bitNumber++;
      }
   return field;
   }

//----------------------------------------------------------------------------
//
// clearbit - force one bit of an extended integer to zero
//
// data      : integer to modify
// bitnumber : zero-based bit position; positions of MAXBITS or more are
//             reported through logError
//
static __forceinline void clearbit (INTEGER *data, unsigned int bitnumber)
   {
   UINTN mask;

   if (!(bitnumber < MAXBITS))
      logError ("need to recompile with more than %u MAXBITS\n", bitnumber);

   mask = (UINTN) 1 << (bitnumber % UINTN_BITS);
   data->uintn [bitnumber / UINTN_BITS] &= ~mask;
   }

//----------------------------------------------------------------------------
//
// setbit - force one bit of an extended integer to one
//
// data      : integer to modify
// bitnumber : zero-based bit position; positions of MAXBITS or more are
//             reported through logError
//
static __forceinline void setbit (INTEGER *data, unsigned int bitnumber)
   {
   UINTN mask;

   if (!(bitnumber < MAXBITS))
      logError ("need to recompile with more than %u MAXBITS\n", bitnumber);

   mask = (UINTN) 1 << (bitnumber % UINTN_BITS);
   data->uintn [bitnumber / UINTN_BITS] |= mask;
   }

//----------------------------------------------------------------------------
//
//  highestSetBit - locate the most significant one bit of an extended
//                  integer by scanning down from bit MAXBITS - 1
//
//  returns the bit position, or -1 when the value is zero
//
static __forceinline signed int highestSetBit (INTEGER *data)
   {
   signed int bitNumber = MAXBITS - 1;

   while (bitNumber >= 0 && !extractbit (data, bitNumber))
      bitNumber--;

   return bitNumber;
   }

//----------------------------------------------------------------------------

//
// populationCount - count the one bits in an extended integer
//
// returns the number of set bits among all MAXBITS positions
//
static unsigned int populationCount (INTEGER *data)
   {
   unsigned int ones = 0;
   unsigned int bitNumber = MAXBITS;

   while (bitNumber != 0)
      {
      bitNumber--;
      ones += extractbit (data, bitNumber);
      }
   return ones;
   }

//----------------------------------------------------------------------------
//
// shiftLeft - shift an extended integer toward the most significant end,
//             in place
//
// integer    : value to shift; replaced by the shifted value
// shiftCount : number of bit positions; MAXBITS or more is reported through
//              logError
//
// bits moved past position MAXBITS - 1 are dropped; vacated low bits are zero
//
static __forceinline void shiftLeft (INTEGER *integer, unsigned int shiftCount)
   {
   INTEGER shifted = IntegerZero;
   unsigned int sourceBit, sourceLimit;

   if (shiftCount >= MAXBITS) logError ("leftShift overflow");

   sourceLimit = MAXBITS - shiftCount;
   for (sourceBit = 0; sourceBit < sourceLimit; sourceBit++)
      {
      if (!extractbit (integer, sourceBit)) continue;
      setbit (&shifted, sourceBit + shiftCount);
      }
   *integer = shifted;
   }

//----------------------------------------------------------------------------
//
// shiftRight - shift an extended integer toward the least significant end,
//              in place
//
// integer    : value to shift; replaced by the shifted value
// shiftCount : number of bit positions; MAXBITS or more is reported through
//              logError
//
// the low shiftCount bits are dropped; vacated high bits are zero
//
static __forceinline void shiftRight (INTEGER *integer, unsigned int shiftCount)
   {
   INTEGER shifted = IntegerZero;
   unsigned int destBit, destLimit;

   if (shiftCount >= MAXBITS) logError ("rightShift overflow");

   destLimit = MAXBITS - shiftCount;
   for (destBit = 0; destBit < destLimit; destBit++)
      {
      if (!extractbit (integer, destBit + shiftCount)) continue;
      setbit (&shifted, destBit);
      }
   *integer = shifted;
   }

//----------------------------------------------------------------------------
//
// rotateLeft - rotate the low columnCount bits of an extended integer
//              toward the most significant end
//
// source      : value to rotate (bits at columnCount and above are ignored)
// dest        : receives the rotated value; may be the same object as source
// columnCount : width of the rotated field in bits
// rotateCount : number of positions; bit i moves to
//               (columnCount + i + rotateCount) % columnCount
//
static __forceinline void rotateLeft (INTEGER *source, INTEGER *dest, unsigned int columnCount, signed int rotateCount)
   {
   INTEGER rotated = {0};
   unsigned int from, to;

   for (from = 0; from < columnCount; from++)
      {
      if (extractbit (source, from) == 0) continue;
      to = (columnCount + from + rotateCount) % columnCount;
      setbit (&rotated, to);
      }
   *dest = rotated;
   }

//----------------------------------------------------------------------------
//
// rotateRight - rotate the low columnCount bits of an extended integer
//               toward the least significant end
//
// source      : value to rotate (bits at columnCount and above are ignored)
// dest        : receives the rotated value; may be the same object as source
// columnCount : width of the rotated field in bits
// rotateCount : number of positions; bit i moves to
//               (columnCount + i - rotateCount) % columnCount, so the
//               magnitude of rotateCount must not exceed columnCount
//
static __forceinline void rotateRight (INTEGER *source, INTEGER *dest, unsigned int columnCount, signed int rotateCount)
   {
   unsigned int columnIndex;
   INTEGER temp = {0};

   for (columnIndex = 0; columnIndex < columnCount; columnIndex++)
      if (extractbit (source, columnIndex))
         // subtract the count: adding it, as rotateLeft does, would move the
         // bits the wrong way. columnCount is added first so the unsigned
         // intermediate cannot go below zero.
         setbit (&temp, (columnCount + columnIndex - (unsigned int) rotateCount) % columnCount);
   *dest = temp;
   }

//----------------------------------------------------------------------------
//
// addInteger - add two extended integers, 16 bits at a time
//
// value1, value2 : operands
// result         : receives value1 + value2; may be the same object as
//                  either operand
//
// a carry out of the most significant element is discarded
//

static void addInteger (INTEGER *value1, INTEGER *value2, INTEGER *result)
   {
   unsigned int index;
   UINT32 carry = 0;

   for (index = 0; index < UINT16_COUNT; index++)
      {
      // a 32-bit sum holds both the 16-bit digit and its carry; splitting it
      // arithmetically avoids punning through a scratch union
      UINT32 sum = (UINT32) value1->uint16 [index] + value2->uint16 [index] + carry;

      result->uint16 [index] = (UINT16) sum;
      carry = sum >> 16;
      }
   }

//----------------------------------------------------------------------------

//
// xorInteger - bitwise exclusive-or of two extended integers
//
// result receives data1 ^ data2 and may be the same object as either input
//
// returns result
//
static __forceinline INTEGER *xorInteger (INTEGER *data1, INTEGER *data2, INTEGER *result)
   {
   unsigned int word;

   for (word = 0; word < UINTN_COUNT; word++)
      {
      UINTN left  = data1->uintn [word];
      UINTN right = data2->uintn [word];

      result->uintn [word] = left ^ right;
      }
   return result;
   }

//----------------------------------------------------------------------------

//
// andInteger - bitwise and of two extended integers
//
// result receives data1 & data2 and may be the same object as either input
//
// returns result
//
static __forceinline INTEGER *andInteger (INTEGER *data1, INTEGER *data2, INTEGER *result)
   {
   unsigned int word;

   for (word = 0; word < UINTN_COUNT; word++)
      {
      UINTN left  = data1->uintn [word];
      UINTN right = data2->uintn [word];

      result->uintn [word] = left & right;
      }
   return result;
   }

//----------------------------------------------------------------------------

//
// andNotInteger - clear in data1 every bit that is set in data2
//
// result receives data1 & ~data2 and may be the same object as either input
//
// returns result
//
static __forceinline INTEGER *andNotInteger (INTEGER *data1, INTEGER *data2, INTEGER *result)
   {
   unsigned int word;

   for (word = 0; word < UINTN_COUNT; word++)
      {
      UINTN keep = data1->uintn [word];
      UINTN drop = data2->uintn [word];

      result->uintn [word] = keep & ~drop;
      }
   return result;
   }

//----------------------------------------------------------------------------
//
// multiplyPolynomial - carry-less (GF2) product of two polynomials whose
//                      coefficients are the bits of extended integers
//
// factor1, factor2 : operands
// product          : receives the product; written only after both operands
//                    have been consumed
//
// product bits that would land at position MAXBITS or above are dropped by
// shiftLeft
//
static void multiplyPolynomial (INTEGER *factor1, INTEGER *factor2, INTEGER *product)
   {
   INTEGER accumulator = IntegerZero;
   INTEGER shifted = *factor2;
   signed int degree = highestSetBit (factor1);
   signed int bit = 0;

   // a zero factor1 (degree -1) skips the loop and yields a zero product
   while (bit <= degree)
      {
      if (extractbit (factor1, bit) != 0)
         xorInteger (&accumulator, &shifted, &accumulator);
      shiftLeft (&shifted, 1);
      bit++;
      }
   *product = accumulator;
   }

//----------------------------------------------------------------------------
//
// multiplyInteger - ordinary (carrying) product of two extended integers by
//                   shift and add
//
// factor1, factor2 : operands
// product          : receives the product; written only after both operands
//                    have been consumed
//
// product bits that would land at position MAXBITS or above are dropped
//
static void multiplyInteger (INTEGER *factor1, INTEGER *factor2, INTEGER *product)
   {
   INTEGER accumulator = IntegerZero;
   INTEGER shifted = *factor2;
   signed int topBit = highestSetBit (factor1);
   signed int bit = 0;

   // a zero factor1 (topBit -1) skips the loop and yields a zero product
   while (bit <= topBit)
      {
      if (extractbit (factor1, bit) != 0)
         addInteger (&accumulator, &shifted, &accumulator);
      shiftLeft (&shifted, 1);
      bit++;
      }
   *product = accumulator;
   }

//----------------------------------------------------------------------------
//
// dividePolynomial - GF2 polynomial long division; coefficients are the
//                    bits of extended integers
//
// numerator   : dividend on entry, remainder on return
// denominator : divisor; left unmodified. A zero divisor is reported through
//               logError.
// quotient    : receives the quotient
//
static void dividePolynomial (INTEGER *numerator, INTEGER *denominator, INTEGER *quotient)
   {
   signed int numeratorPower = highestSetBit (numerator);
   signed int denominatorPower = highestSetBit (denominator);
   signed int denominatorShift = numeratorPower - denominatorPower;
   signed int bitsRetired;
   // the divisor is shifted back and forth during the division; doing that
   // on a private copy keeps the caller's denominator intact
   INTEGER divisor = *denominator;

   *quotient = IntegerZero;
   if (denominatorShift < 0) return; // numerator degree below divisor degree

   if (denominatorPower == -1) logError ("division by zero\n");
   else if (denominatorPower == 0) // division by one
      {
      *quotient = *numerator;
      *numerator = IntegerZero;
      return;
      }

   // line the divisor's leading term up with the numerator's leading term
   shiftLeft (&divisor, denominatorShift);

   for (;;)
      {
      setbit (quotient, denominatorShift);
      xorInteger (numerator, &divisor, numerator);
      // the xor cancels the leading term and possibly several below it
      bitsRetired = numeratorPower - highestSetBit (numerator);
      denominatorShift -= bitsRetired;
      numeratorPower -= bitsRetired;
      if (numeratorPower < denominatorPower) break;
      shiftRight (&divisor, bitsRetired);
      }
   }

//----------------------------------------------------------------------------
//
// modularMultiplyPolynomial - GF2 polynomial product reduced by a modulus
//                             polynomial; coefficients are the bits of
//                             extended integers
//
// factor1 : multiplier, scanned one bit at a time
// factor2 : multiplicand; it is not reduced before use
// modulo  : modulus polynomial; a zero modulus is reported through logError
//           whenever factor1 is non-zero
// product : receives the result; written last, so it may be the same object
//           as either factor
//
static __forceinline void modularMultiplyPolynomial (INTEGER *factor1, INTEGER *factor2, INTEGER *modulo, INTEGER *product)
   {
   signed int index, factor1Power = highestSetBit (factor1), polynomialDegree = highestSetBit (modulo);
   INTEGER  result = {0}, factor2copy = *factor2;

   // a zero modulus has degree -1, which would index far outside the integer
   // in the reduction test below
   if (polynomialDegree < 0 && factor1Power >= 0)
      logError ("modulus polynomial is zero");

   for (index = 0; index <= factor1Power; index++)
      {
      if (extractbit (factor1, index))
         xorInteger (&result, &factor2copy, &result);

      // advance to factor2 * x^(index + 1) and reduce it as soon as its
      // degree reaches the degree of the modulus
      shiftLeft (&factor2copy, 1);
      if (extractbit (&factor2copy, polynomialDegree))
         xorInteger (&factor2copy, modulo, &factor2copy);
      }
   *product = result;
   }

//----------------------------------------------------------------------------
//
// modularPowerPolynomial - raise a GF2 polynomial to a power, reduced by a
//                          modulus polynomial, by square and multiply
//
// primitiveElement : base polynomial
// power            : exponent, held in an extended integer
// modulo           : modulus polynomial
// result           : receives the base raised to power; one when power is
//                    zero
//
static void modularPowerPolynomial (INTEGER *primitiveElement, INTEGER *power, INTEGER *modulo, INTEGER *result)
   {
   // the leading one bit of the exponent is accounted for by starting from
   // the base itself, so scanning begins one position below it
   signed int bitNumber = highestSetBit (power) - 1;

   if (memcmp (power, &IntegerZero, sizeof (INTEGER)) == 0)
      {
      *result = IntegerOne;
      return;
      }

   *result = *primitiveElement;
   while (bitNumber >= 0)
      {
      modularMultiplyPolynomial (result, result, modulo, result);
      if (extractbit (power, bitNumber) != 0)
         modularMultiplyPolynomial (result, primitiveElement, modulo, result);
      bitNumber--;
      }
   }

//----------------------------------------------------------------------------
//
// multiplyByTen - multiply an extended integer by 10 in place, computed as
//                 (value << 3) + (value << 1)
//
// reports through logError once the most significant byte of the result is
// non-zero
//

static void multiplyByTen (INTEGER *value)
   {
   INTEGER timesTwo = *value;
   INTEGER timesEight = *value;

   shiftLeft (&timesTwo, 1);
   shiftLeft (&timesEight, 3);
   addInteger (&timesEight, &timesTwo, value);

   if (value->uint8 [UINT8_COUNT - 1] != 0)
      logError ("overflow is approaching");
   }

//----------------------------------------------------------------------------
//
// subtractInteger - subtract extended integers, 16 bits at a time
//
// value1 : minuend
// value2 : subtrahend
// result : receives value1 - value2; may be the same object as either
//          operand
//
// a borrow out of the most significant element is discarded, so the result
// wraps modulo 2^MAXBITS
//
static void subtractInteger (INTEGER *value1, INTEGER *value2, INTEGER *result)
   {
   unsigned int index;
   UINT32 borrow = 0;

   for (index = 0; index < UINT16_COUNT; index++)
      {
      // when the 16-bit digit underflows, the 32-bit difference wraps and
      // its upper half becomes all ones; bit 16 is therefore the borrow
      UINT32 difference = (UINT32) value1->uint16 [index] - value2->uint16 [index] - borrow;

      result->uint16 [index] = (UINT16) difference;
      borrow = (difference >> 16) & 1;
      }
   }

//----------------------------------------------------------------------------
//
// compareInteger - three-way comparison of extended integers
//
// returns 0 when the values are identical; otherwise value1 - value2 is
// formed and the result is -1 when bit MAXBITS - 1 of that difference is
// set and 1 when it is clear
//
static int compareInteger (INTEGER *value1, INTEGER *value2)
   {
   INTEGER difference;

   if (memcmp (value1, value2, sizeof (INTEGER)) == 0) return 0;

   subtractInteger (value1, value2, &difference);
   if (extractbit (&difference, MAXBITS - 1) != 0)
      return -1;
   return 1;
   }

//----------------------------------------------------------------------------
//
// divideInteger - long division of extended integers by shift and subtract
//
// numerator   : dividend on entry, remainder on return
// denominator : divisor; shifted during the loop and shifted back by the
//               time the loop finishes. A zero divisor is reported through
//               logError.
// quotient    : receives the quotient
//
static void divideInteger (INTEGER *numerator, INTEGER *denominator, INTEGER *quotient)
   {
   signed int numeratorPower = highestSetBit (numerator);
   signed int denominatorPower = highestSetBit (denominator);
   signed int shift = numeratorPower - denominatorPower;
   signed int order;

   if (denominatorPower == -1)
      logError ("division by zero\n");
   else if (denominatorPower == 0) // division by one
      {
      *quotient = *numerator;
      *numerator = IntegerZero;
      return;
      }

   *quotient = IntegerZero;
   order = compareInteger (numerator, denominator);
   if (order < 0) return; // numerator already is the remainder
   if (order == 0)
      {
      *quotient = IntegerOne;
      *numerator = IntegerZero;
      return;
      }

   // line the divisor up under the numerator's leading bit, then walk it
   // back down one position at a time
   shiftLeft (denominator, shift);

   while (shift >= 0)
      {
      if (compareInteger (numerator, denominator) >= 0)
         {
         setbit (quotient, shift);
         subtractInteger (numerator, denominator, numerator);
         }
      else
         {
         shift--;
         if (shift >= 0) shiftRight (denominator, 1);
         }
      }
   }

//----------------------------------------------------------------------------

//
// allOnes - build an extended integer whose low `bits` bits are all one
//
// returns a pointer to a static object that is overwritten by the next call
//
static INTEGER *allOnes (unsigned bits)
   {
   static INTEGER retval;
   unsigned int bitNumber = 0;

   retval = IntegerZero;
   while (bitNumber < bits)
      {
      setbit (&retval, bitNumber);
      bitNumber++;
      }
   return &retval;
   }

//----------------------------------------------------------------------------
//
// decimalAscii - return the decimal ascii representation of an extended
//                integer, followed by a 'd' suffix, in a static buffer.
//                Four calls can be made before a buffer is reused.
//
// data : value to convert; left unmodified
//
// returns a pointer into the static buffer at the first digit
//
static char *decimalAscii (INTEGER *data)
   {
   static  char buffer [4] [MAXBITS / 3 + 2];
   static  unsigned int cycle;
   // digits are produced least significant first, so the text is built
   // backward from the end of the buffer
   char    *result = buffer [cycle] + sizeof (buffer [0]) - 1;
   // repeated division consumes its operand; divide a private copy so the
   // caller's value survives being displayed
   INTEGER remaining = *data, quotient;

   // claim this buffer on every path, including the zero case
   if (++cycle == 4) cycle = 0;

   *result = '\0';
   *--result = 'd';
   if (memcmp (&remaining, &IntegerZero, sizeof (INTEGER)) == 0)
      {
      *--result = '0';
      return result;
      }

   while (memcmp (&remaining, &IntegerZero, sizeof (INTEGER)) != 0)
      {
      // divideInteger leaves the remainder, 0 through 9, in its numerator
      divideInteger (&remaining, &IntegerTen, &quotient);
      *--result = (char) (remaining.uint8 [0] + '0');
      remaining = quotient;
      }

   return result;
   }

//----------------------------------------------------------------------------
//
// binaryAscii - return the binary ascii representation of an extended
//               integer in a static buffer. Four calls can be made before a
//               buffer is reused.
//
// data : value to convert
// bits : number of digits to produce; 0 selects just enough digits to reach
//        the highest set bit, with a minimum of one digit
//
static char *binaryAscii (INTEGER *data, unsigned int bits)
   {
   static char buffer [4] [MAXBITS + 2];
   static unsigned int cycle;
   char *text = buffer [cycle];
   char *cursor = text;
   unsigned int remaining = bits;

   if (remaining == 0) remaining = highestSetBit (data) + 1;
   if (remaining == 0) remaining = 1;

   // most significant digit first
   for (; remaining != 0; cursor++)
      {
      remaining--;
      *cursor = (char) ('0' + extractbit (data, remaining));
      }

   *cursor = '\0';
   cycle = (cycle + 1) % 4;
   return text;
   }

//----------------------------------------------------------------------------
//
// hexAscii - return the hexadecimal ascii representation of an extended
//            integer, with a 0x prefix, in a static buffer. Four calls can
//            be made before a buffer is reused.
//
// data : value to convert
// bits : width to display, rounded up to whole bytes; 0 selects just enough
//        to reach the highest set bit, with a minimum of one byte
//

static char *hexAscii (INTEGER *data, unsigned int bits)
   {
   static char buffer [4] [MAXBITS * 4 + 2];
   static unsigned int cycle;
   char *text = buffer [cycle];
   char *cursor = text;
   unsigned int byteCount;

   if (bits == 0) bits = highestSetBit (data) + 1;
   if (bits == 0) bits = 1;

   *cursor++ = '0';
   *cursor++ = 'x';

   // most significant byte first, two digits per byte
   for (byteCount = (bits + 7) / 8; byteCount != 0; byteCount--)
      cursor += sprintf (cursor, "%02X", data->uint8 [byteCount - 1]);

   *cursor = '\0';
   cycle = (cycle + 1) % 4;
   return text;
   }

//----------------------------------------------------------------------------
//
// octalAscii - return the octal ascii representation of an extended
//              integer, followed by an 'o' suffix, in a static buffer. Four
//              calls can be made before a buffer is reused.
//
// data : value to convert
// bits : width to display, rounded up to whole octal digits and limited to
//        MAXBITS; 0 selects just enough to reach the highest set bit
//

static char *octalAscii (INTEGER *data, unsigned int bits)
   {
   // room for the digit count rounded up, plus the 'o' suffix and terminator
   static char buffer [4] [(MAXBITS + 2) / 3 + 2];
   static unsigned int cycle;
   char *position = buffer [cycle];
   char *result = position;

   if (bits == 0) bits = highestSetBit (data) + 1;
   if (bits > MAXBITS) bits = MAXBITS;
   bits = (bits + 2) / 3 * 3;

   if (bits)
      while (bits)
         {
         // rounding up to a whole digit can reach past the last bit of the
         // integer; those positions are treated as zero rather than read
         unsigned int msb = bits - 1;

         if (msb > MAXBITS - 1) msb = MAXBITS - 1;
         *position = (char) ('0' + extractbits (data, bits - 3, msb));
         position++;
         bits -= 3;
         }
   else *position++ = '0';

   *position++ = 'o';
   *position = '\0';
   if (++cycle == 4) cycle = 0;
   return result;
   }

//----------------------------------------------------------------------------
//
// polynomialAscii - return a text string containing an ascii description of
//                   the polynomial represented by the extended integer, for
//                   example "x^4 + x + 1"; a zero polynomial gives "0"
//
// returns a pointer to a static buffer that is overwritten by the next call
//

static char *polynomialAscii (INTEGER *polynomial)
   {
   // each term needs at most 15 characters (" + x^" plus up to ten exponent
   // digits), so 16 per bit covers every possible polynomial
   static char text [MAXBITS * 16];
   unsigned int    index, plusFlag = 0;
   char        *position = text;

   text [0] = '\0';

   // highest power first; the constant term is handled after the loop
   for (index = MAXBITS - 1; index > 0; index--)
      if (extractbit (polynomial, index))
         {
         if (plusFlag) position += sprintf (position, " + ");
         position += sprintf (position, "x");
         if (index > 1) position += sprintf (position, "^%u", index);
         plusFlag = 1;
         }

   if (extractbit (polynomial, 0))
      {
      if (plusFlag) position += sprintf (position, " + ");
      position += sprintf (position, "1");
      plusFlag = 1;
      }

   // no term at all: show zero instead of whatever the last call left behind
   if (!plusFlag) sprintf (position, "0");
   return text;
   }

//----------------------------------------------------------------------------
//
// integerToAscii - return the ascii representation of an extended integer
//                  in a static buffer, formatted according to displayBase
//
// data : value to convert
// bits : display width handed to the binary, octal and hex formatters; the
//        decimal and polynomial formatters do not take it
//

static char *integerToAscii (INTEGER *data, unsigned int bits)
   {
   switch (displayBase)
      {
      case 2:
         return binaryAscii (data, bits);
      case 8:
         return octalAscii (data, bits);
      case 10:
         return decimalAscii (data);
      case 16:
         return hexAscii (data, bits);
      default:
         return polynomialAscii (data);
      }
   }

//----------------------------------------------------------------------------
//
// findBase - looks at ascii buffer and determines the base. A leading 0x forces hex.
//            A trailing b, o, d forces binary, octal, or decimal, respectively.
//            If no prefix or suffix is present, a best guess is made by scanning the digits.
//
// position    : text of the number. When a suffix letter is found it is
//               overwritten with a terminator, so the buffer must be writable.
// defaultBase : lower limit applied to the guessed base; it is not applied
//               when a prefix or suffix decides the base
//
// returns 16, 2, 8 or 10 for an explicit prefix or suffix, otherwise the
// larger of the guessed base and defaultBase
//
static unsigned int findBase (char *position, int defaultBase)
   {
   int      base = 0, maxCharacter = '0';
   int      suffix;
   char     *endOfNumber;

   // characters go through unsigned char before reaching <ctype.h>, which
   // is undefined for negative values
   if (position [0] == '0' && tolower ((unsigned char) position [1]) == 'x') return 16;

   endOfNumber = position + strspn (position, "0123456789");
   suffix = tolower ((unsigned char) *endOfNumber);
   if (suffix == 'b') base = 2;
   if (suffix == 'o') base = 8;
   if (suffix == 'd') base = 10;

   if (base)
      {
      *endOfNumber = '\0';
      return base;
      }

   // no prefix or suffix, look at digits
   while (position != endOfNumber)
      {
      if (maxCharacter < *position)
         maxCharacter = *position;
      position++;
      }

   if (maxCharacter <= '1') base = 2; 
   else if (maxCharacter <= '7') base = 8; 
   else if (maxCharacter <= '9') base = 10;
   return (base > defaultBase) ? base : defaultBase;
   }

//----------------------------------------------------------------------------
//
// scanDecimalDigits - read an extended integer from an ascii buffer of
//                     decimal digits
//
// buffer : text to read; scanning stops at the first character that is not
//          a decimal digit
// value  : receives the number; zero when buffer does not start with a digit
//
static void scanDecimalDigits (char *buffer, INTEGER *value)
   {
   char     *position = buffer;
   INTEGER  dvalue = IntegerZero;

   *value = IntegerZero;

   for (;;)
      {
      // read through unsigned char: isdigit is undefined for the negative
      // values a plain char can hold
      unsigned char digit = (unsigned char) *position++;

      if (!isdigit (digit)) break;
      multiplyByTen (value);
      dvalue.uintn [0] = digit - '0';
      addInteger (value, &dvalue, value);
      }
   }

//----------------------------------------------------------------------------
//
// scanDigits - read an extended integer from an ascii buffer of digits of a
//              selectable base
//
// buffer      : text of the number; it is copied first, so the caller's
//               text is not modified
// integer     : receives the value
// defaultBase : passed to findBase as the lower limit for a guessed base
//
// The number must be followed by a comma, space, newline or end of string;
// anything else, a failed allocation, or a base with no table entry is
// reported through logError.
//
static void scanDigits (char *buffer, INTEGER *integer, unsigned int defaultBase)
   {
   // bits contributed by one digit, indexed by base
   static int digitBits [] = {0, 0, 1, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 4};
   char *endOfNumber;
   unsigned int bitNumber, index, base, bitsPerDigit;
   size_t digitsLeft;
   char *tmpBuffer = strdup (buffer);
   char *start = tmpBuffer;

   if (tmpBuffer == NULL) logError ("out of memory");

   // findBase may overwrite a suffix letter, hence the private copy
   base = findBase (start, defaultBase);
   if (base == 10)
      {
      scanDecimalDigits (start, integer);
      free (tmpBuffer);
      return;
      }

   if (base >= sizeof (digitBits) / sizeof (digitBits [0]))
      logError ("unsupported base %u", base);

   if (start [0] == '0' && tolower ((unsigned char) start [1]) == 'x') start += 2;
   if (base == 16)
      endOfNumber = start + strspn (start, "0123456789abcdefABCDEF");
   else if (base == 8)
      endOfNumber = start + strspn (start, "01234567");
   else
      endOfNumber = start + strspn (start, "01");

   if (*endOfNumber != ',' && *endOfNumber != '\0' && *endOfNumber != '\n' && *endOfNumber != ' ')
      logError ("invalid base %u digit: '%c'", base, *endOfNumber);
   bitsPerDigit = digitBits [base];
   bitNumber = 0;
   *integer = IntegerZero;

   // walk the digits from least significant (last) to most significant,
   // counting down an index so no pointer is formed below the buffer start
   digitsLeft = (size_t) (endOfNumber - start);
   while (digitsLeft--)
      {
      char digit = start [digitsLeft];
      unsigned int value = (unsigned int) toupper ((unsigned char) digit);

      if (value >= 'A') value = 10 + (value - 'A');
      else value -= '0';
      if (value >= base) logError ("invalid base %u digit (%c)", base, digit);
      for (index = 0; index < bitsPerDigit; index++)
         {
         if (value & 1)
            setbit (integer, bitNumber);
         bitNumber++;
         value >>= 1;
         }
      }
   free (tmpBuffer);
   return;
   }

//----------------------------------------------------------------------------

//
// isNumber - decide whether a whole string is a numeric literal
//
// Accepted forms: "0x" followed only by hex digits, or a run of decimal
// digits optionally followed by exactly one suffix letter o, b or d in
// either case.
//
// returns 1 when the string is a number, 0 otherwise
//
static unsigned int isNumber (char *position)
   {
   char *endOfNumber;

   // test the prefix one character at a time so an empty or one-character
   // string is never read past its terminator
   if (position [0] == '0' && position [1] == 'x')
      {
      position += 2;
      endOfNumber = position + strspn (position, "0123456789abcdefABCDEF");
      // a hex literal is decided here alone; falling through to the decimal
      // rules would accept text such as "0x12o"
      return *endOfNumber == '\0';
      }

   if (!isdigit ((unsigned char) *position)) return 0;

   endOfNumber = position + strspn (position, "0123456789");
   if (endOfNumber [0] == '\0') return 1;
   if (endOfNumber [1] != '\0') return 0;

   switch (tolower ((unsigned char) *endOfNumber))
      {
      case 'o':
      case 'b':
      case 'd':
         return 1;
      default:
         return 0;
      }
   }

//----------------------------------------------------------------------------

//
// skipComma - find the next comma in a string
//
// returns a pointer to the character after the comma; a string with no
// comma is reported through logError
//
static char *skipComma (char *position)
   {
   char *comma = strchr (position, ',');

   if (comma == NULL) logError ("expected comma");
   return comma + 1;
   }

//----------------------------------------------------------------------------

//
// nextRightGaloisLfsr - advance a right-shifting Galois LFSR by one step
//
// state : register contents, updated in place
// mask  : value xored into the state when the bit shifted out is a one
//
static void nextRightGaloisLfsr (INTEGER *state, INTEGER *mask)
   {
   unsigned int shiftedOut = extractbit (state, 0);

   shiftRight (state, 1);
   if (shiftedOut != 0)
      xorInteger (state, mask, state);
   }

//----------------------------------------------------------------------------

//
// nextLeftGaloisLfsr - advance a left-shifting Galois LFSR by one step
//
// state : register contents, updated in place
// mask  : value xored into the state when the top register bit was a one
// bits  : number of registers in the LFSR
//
static void nextLeftGaloisLfsr (INTEGER *state, INTEGER *mask, unsigned int bits)
   {
   unsigned int topBit = bits - 1;
   unsigned int shiftedOut = extractbit (state, topBit);

   shiftLeft (state, 1);
   if (shiftedOut != 0)
      xorInteger (state, mask, state);

   // keep the state within the register width
   clearbit (state, bits);
   }

//----------------------------------------------------------------------------

//
// nextLeftFibonacciLfsr - advance a left-shifting Fibonacci LFSR by one step
//
// state : register contents, updated in place
// mask  : tap selection; the parity of (mask & state) becomes the new bit 0
// bits  : number of registers in the LFSR
//
static void nextLeftFibonacciLfsr (INTEGER *state, INTEGER *mask, unsigned int bits)
   {
   INTEGER tapped;
   unsigned int parity;

   andInteger (mask, state, &tapped);
   parity = populationCount (&tapped) & 1;

   shiftLeft (state, 1);
   clearbit (state, bits);
   if (parity != 0)
      setbit (state, 0);

   // keep the state within the register width
   clearbit (state, bits);
   }

//----------------------------------------------------------------------------

//
// nextRightFibonacciLfsr - advance a right-shifting Fibonacci LFSR by one
//                          step
//
// state : register contents, updated in place
// mask  : tap selection; the parity of (mask & state) becomes the new top
//         register bit
// bits  : number of registers in the LFSR
//
static void nextRightFibonacciLfsr (INTEGER *state, INTEGER *mask, unsigned int bits)
   {
   INTEGER tapped;
   unsigned int parity;

   andInteger (mask, state, &tapped);
   parity = populationCount (&tapped) & 1;

   shiftRight (state, 1);
   if (parity != 0)
      setbit (state, bits - 1);
   }

//----------------------------------------------------------------------------

//
// skipWhiteSpace - step over leading spaces and tabs
//
// returns a pointer to the first character that is neither a space nor a tab
//
static char *skipWhiteSpace (char *position)
   {
   return position + strspn (position, " \t");
   }

//----------------------------------------------------------------------------

//
// nextValue - replace *current with the xor of generator [i] over every bit
//             position i below maxBits that is set in *current
//
// current   : input value, overwritten with the result
// generator : array indexed by bit position; must hold at least maxBits
//             entries
// maxBits   : number of low bit positions of *current to examine
//
static void nextValue (INTEGER *current, INTEGER *generator, unsigned int maxBits)
   {
   INTEGER folded = IntegerZero;
   unsigned int bit = 0;

   while (bit < maxBits)
      {
      if (extractbit (current, bit) != 0)
         xorInteger (&folded, generator + bit, &folded);
      bit++;
      }
   *current = folded;
   }

//----------------------------------------------------------------------------

//
// leftOrRight - read a direction letter from the start of a string
//
// returns 'l' or 'r'; any other first character is reported through logError
//
static unsigned leftOrRight (char *position)
   {
   char direction = *position;

   if (direction == 'r' || direction == 'l')
      return direction;

   logError ("expected l or r");
   return 0;
   }

//----------------------------------------------------------------------------

//
// help - print worked command-line examples to stdout
//
static void help (void)
   {
   printf ("large integer input, output, and conversions:\n");
   printf ("\n");
   printf ("   gf2util display=hex 101   // default polynomial input is smallest valid base\n");
   printf ("   0x05\n");
   printf ("\n");
   printf ("   gf2util display=hex 101d  // use d suffix to force decimal\n");
   printf ("   0x65\n");
   printf ("\n");
   printf ("   gf2util display=hex 101o  // use o suffix to force octal\n");
   printf ("   0x41\n");
   printf ("\n");
   printf ("   gf2util display=hex 0x101 // use C notation for hex input\n");
   printf ("   0x101\n");
   printf ("\n");
   printf ("   gf2util display=poly 101  // example of polynomial display mode\n");
   printf ("   x^2 + 1\n");
   printf ("\n");
   printf ("\n");
   printf ("operations on polynomials with GF2 (binary) coefficients kept in large integers:\n");
   printf ("\n");
   printf ("   gf2util xor,1111100000,1111111111 // with binary coefficients, add and\n");
   printf ("   0000011111                        // subtract both map to XOR\n");
   printf ("\n");
   printf ("   gf2util multiply,101,111001       // not the same as ordinary multiply,\n");
   printf ("   11011101                          //   111001  because sub-totals are\n");
   printf ("                                     // 111001    combined with xor, not add\n");
   printf ("                                     // --------\n");
   printf ("                                     // 11011101\n");
   printf ("\n");
   printf ("   gf2util divide,11011101,111001    // like multiply, subtract is\n");
   printf ("   quotient 101 remainder 00000      // replaced by xor\n");
   printf ("\n");
   printf ("   gf2util factor,1101110101         // GF2 factor is consistent with\n");
   printf ("   111                               // GF2 multiply and divide\n");
   printf ("   10100111\n");
   printf ("\n");
   printf ("   gf2util power,101,3               // short for multiply,101,101,101\n");
   printf ("   1010101\n");
   printf ("\n");
   printf ("   gf2util modmul,11001,1011,1111    // modulo multiply, short for:\n");
   printf ("   1101                              // multiply 1011,1111\n");
   printf ("                                     // divide result by 11001, keep remainder\n");
   printf ("\n");
   printf ("   gf2util modpower,10011,10,3       // modulo power, short for:\n");
   printf ("   1000                              // modmul,10011,10,10,10\n");
   printf ("\n");
   printf ("   gf2util lfsr,g,l,10011,1          // Galois LFSR, same as shifting left\n");
   printf ("   0001                              // the previous state then XORing with\n");
   printf ("   0010                              // 0011 if the left shift shifted out a\n");
   printf ("   0100                              // one. Note: the XOR mask value 0011 is\n");
   printf ("   1000                              // the value 10011 above, with ms bit\n");
   printf ("   0011                              // cleared. Also notice this is the same\n");
   printf ("   0110                              // sequence produced by gf2util modmul,\n");
   printf ("   1100                              // 10011,10,1 (,2 ,3 ,4 etc. For either\n");
   printf ("   1011                              // of these to make a maximal length\n");
   printf ("   0101                              // sequence, the first polynomial, 10011\n");
   printf ("   1010                              // in this example, must be a primitive\n");
   printf ("   0111                              // polynomial. The Galois LFSR is really\n");
   printf ("   1110                              // just an optimization of the above\n");
   printf ("   1111                              // successive modpower sequence.\n");
   printf ("   1101\n");
   printf ("   1001\n");
   printf ("   0001\n");
   printf ("\n");
   printf ("   gf2util lfsr,g,r,10011,1,14d      // demonstrates how to advance the above\n");
   printf ("   0010                              // 14 times by looping\n");
   printf ("\n");
   printf ("   gf2util modpower,10011,1001,14d   // demonstrates how to advance the above\n");
   printf ("   0010                              // 14 times by direct calculation\n");
   }

//----------------------------------------------------------------------------

// File-scope state with external linkage; none of the functions above
// reference these.
// NOTE(review): presumably used by the command handling in main - confirm
// against the rest of the file before changing the initial values.
INTEGER *lfsrReference;
unsigned referenceCount = 1, unique = 1;


int main (unsigned argc, char *argv [])
   {
   unsigned int index;

   if (argc == 1)
      {
      printf ("use options:\n");
      printf ("    display=...       binary,octal,decimal,hex,poly (result display mode)\n");  
      printf ("    multiply,p,p,...  multiply 2 or more polynomials (separate with commas)\n");  
      printf ("    divide,p,p,...    divide numerator polynomial by divisor polynomial(s)\n");  
      printf ("    shiftleft p,n     shift left polynomial p by n bits\n");  
      printf ("    shiftright p,n    shift right polynomial p by n bits\n");  
      printf ("    xor p1,p2         xor polynomials p1 and p2\n");  
      printf ("    power,p,n         raise polynomial p to n power\n");
      printf ("    modmul,p1,p2,...  mod p1, multiply p2,...\n");
      printf ("    modpower,p1,p2,n  mod p1, raise polynomial p2 to n power\n");
      printf ("    gf,p1,p2          dump GF elements, p1=primitive poly, p2=primitive element\n");
      printf ("    gfmisc,p1         find relationships between field representations, p1 is primitive\n");
      printf ("    lfsr,g,r,p1,p2,n  advance Galois LFSR state n times, r=right, p1=poly, p2=initial state\n");
      printf ("    lfsr,f,l,p1,p2,n  advance Fibonacci LFSR state n times, l=left, p1=poly, p2=initial state\n");
      printf ("    help              show some examples\n");
      printf ("\ninput n is decimal, input p is a polynomial in one of these forms:\n");
      printf ("    0x.... for hex,\n");
      printf ("    .....b for binary,\n");
      printf ("    .....o for octal,\n");
      printf ("    .....d for decimal\n");
      exit (1);
      }

   for (index = 1; index < argc; index++)
      {
      char *position = argv [index];

      if (memcmp (position, "display=", 8) == 0)
         {
         position += 8;
         if (strcmp (position, "binary") == 0) displayBase = 2;
         else if (strcmp (position, "octal") == 0) displayBase = 8;
         else if (strcmp (position, "decimal") == 0) displayBase = 10;
         else if (strcmp (position, "hex") == 0) displayBase = 16;
         else if (strcmp (position, "poly") == 0) displayBase = 0;
         else logError ("unknown command line argument: %s", argv [index]);
         }
      else if (memcmp (position, "divide", 6) == 0) // divide,numerator,divisor
         {
         INTEGER numerator, denominator, quotient;

         position = skipComma (position);
         scanDigits (position, &numerator, 2);

         for (;;)
            {
            unsigned int denominatorBits;
            position = strchr (position, ',');
            if (!position) break;
            position++;
            scanDigits (position, &denominator, 2);
            denominatorBits = highestSetBit (&denominator);
            dividePolynomial (&numerator, &denominator, &quotient);
            printf ("quotient %s remainder %s\n", integerToAscii (&quotient, 0), integerToAscii (&numerator, denominatorBits));
            if (memcmp (&quotient, &IntegerZero, sizeof (INTEGER)) == 0) break;
            numerator = quotient;
            quotient = IntegerZero;
            }
         }
      else if (memcmp (position, "multiply", 8) == 0) // multiply,arg1,arg2[,arg3...]
         {
         INTEGER factor, total;

         position = skipComma (position);
         scanDigits (position, &total, 2);
         for (;;)
            {
            position = strchr (position, ',');
            if (!position) break;
            position++;
            scanDigits (position, &factor, 2);
            multiplyPolynomial (&total, &factor, &total);
            }
         printf ("%s\n", integerToAscii (&total, 0));
         }
      else if (memcmp (position, "modmul", 6) == 0) // modmul,mod,arg1,arg2[,arg3...]
         {
         INTEGER factor, total = IntegerOne, modulo, remainder;
         int bits;

         position = skipComma (position);
         scanDigits (position, &modulo, 2);
         bits = highestSetBit (&modulo);
         for (;;)
            {
            position = strchr (position, ',');
            if (!position) break;
            position++;
            scanDigits (position, &factor, 2);
            multiplyPolynomial (&total, &factor, &total);
            }
         remainder = total;
         dividePolynomial (&remainder, &modulo, &total);
         printf ("%s\n", integerToAscii (&remainder, bits));
         }
      else if (memcmp (position, "shiftleft", 9) == 0) // shiftleft,arg,count
         {
         INTEGER integer, shiftCount;

         position = skipComma (position);
         scanDigits (position, &integer, 2);
         position = skipComma (position);
         scanDigits (position, &shiftCount, 10);
         shiftLeft (&integer, shiftCount.uintn [0]);
         printf ("%s\n", integerToAscii (&integer, 0));
         }
      else if (memcmp (position, "shiftright", 10) == 0) // shiftright,arg,count
         {
         INTEGER integer, shiftCount;

         position = skipComma (position);
         scanDigits (position, &integer, 2);
         position = skipComma (position);
         scanDigits (position, &shiftCount, 10);
         shiftRight (&integer, shiftCount.uintn [0]);
         printf ("%s\n", integerToAscii (&integer, 0));
         }
      else if (memcmp (position, "xor", 3) == 0) // xor,p1,p2
         {
         INTEGER p1, p2, result;

         position = skipComma (position);
         scanDigits (position, &p1, 2);
         position = skipComma (position);
         scanDigits (position, &p2, 2);
         xorInteger (&p1, &p2, &result);
         printf ("%s\n", integerToAscii (&result, max (highestSetBit (&p1), highestSetBit (&p2)) + 1));
         }
      else if (memcmp (position, "modpower", 8) == 0) // modpower,mod,element,power
         {
         INTEGER alpha, result, modulo, power;

         position = skipComma (position);
         scanDigits (position, &modulo, 2);
         position = skipComma (position);
         scanDigits (position, &alpha, 2);
         position = skipComma (position);
         scanDigits (position, &power, 10);
         modularPowerPolynomial (&alpha, &power, &modulo, &result);
         printf ("%s\n", integerToAscii (&result, highestSetBit (&modulo)));
         }
      else if (memcmp (position, "power", 5) == 0) // power,alpha,n
         {
         INTEGER total = IntegerOne, alpha, exponent;

         position = skipComma (position);
         scanDigits (position, &alpha, 2);
         position = skipComma (position);
         scanDigits (position, &exponent, 10);
         while (exponent.uintn [0]--)
            multiplyPolynomial (&total, &alpha, &total);
         printf ("%s\n", integerToAscii (&total, 0));
         }
      else if (memcmp (position, "lfsr,g", 6) == 0) // lfsr,g,d,state,polynomial [,advanceCount]
         {
         INTEGER primitivePolynomial, state, start, mask;
         unsigned int index, elements, direction;
         signed int degree;

         position = skipComma (position + 6);
         direction = leftOrRight (position);
         position = skipComma (position);
         scanDigits (position, &primitivePolynomial, 2);
         position = skipComma (position);
         scanDigits (position, &state, 2);
         degree = highestSetBit (&primitivePolynomial);
         if (highestSetBit (&state) + 1 > degree) logError ("initial state is too big");
         mask = primitivePolynomial;
         if (direction == 'r')
            shiftRight (&mask, 1);
         else
            clearbit (&mask, degree);

         position = strchr (position, ',');
         if (position)
            {
            INTEGER advanceCount;
            scanDigits (position + 1, &advanceCount, 10);
            if (direction == 'r')
               for (index = 0; index < advanceCount.uintn [0]; index++)
                  nextRightGaloisLfsr (&state, &mask);
            else
               for (index = 0; index < advanceCount.uintn [0]; index++)
                  nextLeftGaloisLfsr (&state, &mask, degree);
            printf ("%s\n", integerToAscii (&state, degree));
            }
         else // no count argument, dump entire sequence
            {
            start = state;
            elements = 1 << degree;
            // limit to partial list if the native integer overflows
            if (degree >= sizeof (elements) * 8) elements = ~0;
            if (direction == 'r')
               for (index = 0; index < elements; index++)
                  {
                  printf ("%s\n", integerToAscii (&state, degree));
                  nextRightGaloisLfsr (&state, &mask);
                  if (memcmp (&state, &start, sizeof state) == 0) break;
                  }
            else
               for (index = 0; index < elements; index++)
                  {
                  printf ("%s\n", integerToAscii (&state, degree));
                  nextLeftGaloisLfsr (&state, &mask, degree);
                  if (memcmp (&state, &start, sizeof state) == 0) break;
                  }
            }
         }

      else if (memcmp (position, "lfsr,f", 6) == 0) // lfsr,f,d,state,polynomial [,advanceCount]
         {
         INTEGER    primitivePolynomial, state, start, mask;
         unsigned int   index, elements, direction;
         signed int degree;

         position = skipComma (position + 6);
         direction = leftOrRight (position);
         position = skipComma (position);
         scanDigits (position, &primitivePolynomial, 2);
         position = skipComma (position);
         scanDigits (position, &state, 2);
         degree = highestSetBit (&primitivePolynomial);
         if (highestSetBit (&state) + 1 > degree) logError ("initial state is too big");
         mask = primitivePolynomial;
         if (direction == 'l')
            shiftRight (&mask, 1);
         else
            clearbit (&mask, degree);

         position = strchr (position, ',');
         if (position)
            {
            INTEGER advanceCount;
            scanDigits (position + 1, &advanceCount, 10);
            if (direction == 'r')
               for (index = 0; index < advanceCount.uintn [0]; index++)
                  nextRightFibonacciLfsr (&state, &mask, degree);
            else
               for (index = 0; index < advanceCount.uintn [0]; index++)
                  nextLeftFibonacciLfsr (&state, &mask, degree);
            printf ("%s\n", integerToAscii (&state, degree));
            }
         else // no count argument, dump entire sequence
            {
            start = state;
            elements = 1 << degree;
            // limit to partial list if the native integer overflows
            if (degree >= sizeof (elements) * 8) elements = ~0;
            if (direction == 'r')
               for (index = 0; index < elements; index++)
                  {
                  printf ("%s\n", integerToAscii (&state, degree));
                  nextRightFibonacciLfsr (&state, &mask, degree);
                  if (memcmp (&state, &start, sizeof state) == 0) break;
                  }
            else
               for (index = 0; index < elements; index++)
                  {
                  printf ("%s\n", integerToAscii (&state, degree));
                  nextLeftFibonacciLfsr (&state, &mask, degree);
                  if (memcmp (&state, &start, sizeof state) == 0) break;
                  }
            }
         }
      else if (isNumber (position))
         {
         INTEGER value;
         scanDigits (position, &value, 2);
         printf ("%s\n", integerToAscii (&value, 0));
         }

      else if (strcmp (position, "help") == 0)
         help ();
      else
         logError ("unknown command line argument: %s", position);
      }
   return 0;
   }

