2010-02-05 12 views
12

Quindi ho lavorato di recente a un'implementazione del test di primalità di Miller-Rabin. Lo sto limitando ad un ambito di tutti i numeri a 32 bit, perché questo è un progetto per me divertente che sto facendo per familiarizzare con C++, e non voglio dover lavorare con qualcosa a 64-bit per un po. Un ulteriore vantaggio è che l'algoritmo è deterministico per tutti i numeri a 32 bit, quindi posso aumentare significativamente l'efficienza perché so esattamente a quali testimoni fare test.Esponenziazione modulare per numeri elevati in C++

Quindi per i numeri bassi, l'algoritmo funziona eccezionalmente bene. Tuttavia, parte del processo si basa sull'esponenziazione modulare, ovvero (num^pow)% mod. così, ad esempio,

3^2 % 5 = 
9 % 5 = 
4 

Ecco il codice che ho usato per questo elevamento a potenza modulare:

unsigned mod_pow(unsigned num, unsigned pow, unsigned mod) 
{ 
    unsigned test; 
    for(test = 1; pow; pow >>= 1) 
    { 
     if (pow & 1) 
      test = (test * num) % mod; 
     num = (num * num) % mod; 
    } 

    return test; 

} 

Come forse avrete già indovinato, i problemi sorgono quando gli argomenti sono tutti eccezionalmente grandi numeri. Ad esempio, se voglio testare il numero 673.109 per la primalità, io a un certo punto dovrà trovare:

(2^168277)% 673109

ora 2^168.277 è un numero eccezionalmente elevato, e da qualche parte nel processo esso supera il test, il che si traduce in una valutazione errata.

sul retro, argomenti come

4000111222^3% 1608

anche valutare correttamente, per la stessa ragione.

Qualcuno ha suggerimenti per l'esponenziazione modulare in modo tale da impedire questo overflow e/o manipolarlo per produrre il risultato corretto? (Il mio modo di vedere, troppo pieno è solo un'altra forma di modulo, cioè num% (UINT_MAX + 1))

risposta

7

Exponentiation by squaring ancora "funziona" per modulo elevatore. Il tuo problema non è che 2^168277 è un numero eccezionalmente grande, è che uno dei tuoi risultati intermedi è un numero abbastanza grande (più grande di 2^32), perché 673109 è più grande di 2^16.

Quindi penso che quanto segue farà. È possibile che mi sia sfuggito un dettaglio, ma l'idea di base funziona, e questo è il modo in cui il codice crittografico "reale" potrebbe eseguire un mod-esponenziazione di grandi dimensioni (anche se non con numeri a 32 e 64 bit, piuttosto con bignum che non devono mai diventare più grandi di 2 * log (modulo)):

  • Iniziare con l'elevazione a potenza a squadrare, come si ha.
  • Esegue la quadratura effettiva in un numero intero senza segno a 64 bit.
  • Riduci il modulo 673109 a ogni passaggio per tornare all'interno della gamma a 32 bit, come fai tu.

Ovviamente questo è un po 'imbarazzante se l'implementazione del C++ non ha un numero intero a 64 bit, anche se si può sempre fingere uno.

C'è un esempio sulla diapositiva 22 qui: http://www.cs.princeton.edu/courses/archive/spr05/cos126/lectures/22.pdf, anche se utilizza numeri molto piccoli (meno di 2^16), quindi potrebbe non illustrare nulla che non si conosce già.

L'altro esempio, 4000111222^3 % 1608 funzionerebbe nel codice corrente se si riduce solo 4000111222 modulo 1608 prima di iniziare. 1608 è abbastanza piccolo da poter tranquillamente moltiplicare due numeri mod-1608 in un int a 32 bit.

+0

grazie amico, quello ha fatto il trucco. Solo per curiosità, conosci qualche metodo che non richiederebbe l'uso di una memoria più grande? Sono sicuro che sarebbero tornati utili. –

+0

Non che io sappia. Devi moltiplicare insieme due numeri fino a 673108, mod 673109. Ovviamente potresti rompere le cose e fare una lunga moltiplicazione con "cifre" più piccole, ad esempio 2^10. Ma non appena implementate la moltiplicazione e la divisione nel software, potreste anche implementarlo per il caso speciale di moltiplicare insieme due valori a 32 bit per ottenere un risultato a 64 bit, quindi dividendo per estrarre un resto a 32 bit. Potrebbero esserci alcune ottimizzazioni hard-core che fanno il minimo indispensabile, ma non le conosco e fingere un int a 64 bit in C++ non è * così * difficile. –

3

due cose:

  • Si sta utilizzando il tipo di dati appropriato? In altre parole, UINT_MAX ti consente di avere 673109 come argomento?

No, non, dal momento che ad un certo punto avete il vostro codice non funziona, perché a un certo punto si ha num = 2^16 e la num = ... causa di overflow. Utilizzare un tipo di dati più grande per mantenere questo valore intermedio.

  • ne dite di prendere modulo ad ogni possibile vedere sottacqua troppo pieno come ad esempio:

    test = ((test % mod) * (num % mod)) % mod;

Edit:

unsigned mod_pow(unsigned num, unsigned pow, unsigned mod) 
{ 
    unsigned long long test; 
    unsigned long long n = num; 
    for(test = 1; pow; pow >>= 1) 
    { 
     if (pow & 1) 
      test = ((test % mod) * (n % mod)) % mod; 
     n = ((n % mod) * (n % mod)) % mod; 
    } 

    return test; /* note this is potentially lossy */ 
} 

int main(int argc, char* argv[]) 
{ 

    /* (2^168277) % 673109 */ 
    printf("%u\n", mod_pow(2, 168277, 673109)); 
    return 0; 
} 
-1

È possibile utilizzare seguente identità:

(a * b) (mod m) === (a (mod m)) * (b (mod m)) (mod m)

Provare ad usarlo in modo semplice e migliorare in modo incrementale.

if (pow & 1) 
     test = ((test % mod) * (num % mod)) % mod; 
    num = ((num % mod) * (num % mod)) % mod; 
+1

grazie per entrambi i suggerimenti, ma dalla natura dell'algoritmo sia test che num saranno sempre inferiori a mod, quindi: {(test% mod) = test} e {(num% mod) = test} quindi l'identità non può aiutarmi perché la funzione fallisce anche quando num e test sono inferiori a mod. Inoltre, gli input non firmati mi consentono di avere 673109 come argomento. UINT_MAX = 4 294 967 295 per il mio computer. –

+0

Ho aggiunto il frammento di codice; per favore vedi voglio volevo dire. –

5

Ho scritto qualcosa per questo recentemente per RSA in C++, un po 'disordinato però.

#include "BigInteger.h" 
#include <iostream> 
#include <sstream> 
#include <stack> 

BigInteger::BigInteger() { 
    digits.push_back(0); 
    negative = false; 
} 

BigInteger::~BigInteger() { 
} 

void BigInteger::addWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) { 
    int sum_n_carry = 0; 
    int n = (int)a.digits.size(); 
    if (n < (int)b.digits.size()) { 
     n = b.digits.size(); 
    } 
    c.digits.resize(n); 
    for (int i = 0; i < n; ++i) { 
     unsigned short a_digit = 0; 
     unsigned short b_digit = 0; 
     if (i < (int)a.digits.size()) { 
      a_digit = a.digits[i]; 
     } 
     if (i < (int)b.digits.size()) { 
      b_digit = b.digits[i]; 
     } 
     sum_n_carry += a_digit + b_digit; 
     c.digits[i] = (sum_n_carry & 0xFFFF); 
     sum_n_carry >>= 16; 
    } 
    if (sum_n_carry != 0) { 
     putCarryInfront(c, sum_n_carry); 
    } 
    while (c.digits.size() > 1 && c.digits.back() == 0) { 
     c.digits.pop_back(); 
    } 
    //std::cout << a.toString() << " + " << b.toString() << " == " << c.toString() << std::endl; 
} 

void BigInteger::subWithoutSign(BigInteger& c, const BigInteger& a, const BigInteger& b) { 
    int sub_n_borrow = 0; 
    int n = a.digits.size(); 
    if (n < (int)b.digits.size()) 
     n = (int)b.digits.size(); 
    c.digits.resize(n); 
    for (int i = 0; i < n; ++i) { 
     unsigned short a_digit = 0; 
     unsigned short b_digit = 0; 
     if (i < (int)a.digits.size()) 
      a_digit = a.digits[i]; 
     if (i < (int)b.digits.size()) 
      b_digit = b.digits[i]; 
     sub_n_borrow += a_digit - b_digit; 
     if (sub_n_borrow >= 0) { 
      c.digits[i] = sub_n_borrow; 
      sub_n_borrow = 0; 
     } else { 
      c.digits[i] = 0x10000 + sub_n_borrow; 
      sub_n_borrow = -1; 
     } 
    } 
    while (c.digits.size() > 1 && c.digits.back() == 0) { 
     c.digits.pop_back(); 
    } 
    //std::cout << a.toString() << " - " << b.toString() << " == " << c.toString() << std::endl; 
} 

int BigInteger::cmpWithoutSign(const BigInteger& a, const BigInteger& b) { 
    int n = (int)a.digits.size(); 
    if (n < (int)b.digits.size()) 
     n = (int)b.digits.size(); 
    //std::cout << "cmp(" << a.toString() << ", " << b.toString() << ") == "; 
    for (int i = n-1; i >= 0; --i) { 
     unsigned short a_digit = 0; 
     unsigned short b_digit = 0; 
     if (i < (int)a.digits.size()) 
      a_digit = a.digits[i]; 
     if (i < (int)b.digits.size()) 
      b_digit = b.digits[i]; 
     if (a_digit < b_digit) { 
      //std::cout << "-1" << std::endl; 
      return -1; 
     } else if (a_digit > b_digit) { 
      //std::cout << "+1" << std::endl; 
      return +1; 
     } 
    } 
    //std::cout << "0" << std::endl; 
    return 0; 
} 

void BigInteger::multByDigitWithoutSign(BigInteger& c, const BigInteger& a, unsigned short b) { 
    unsigned int mult_n_carry = 0; 
    c.digits.clear(); 
    c.digits.resize(a.digits.size()); 
    for (int i = 0; i < (int)a.digits.size(); ++i) { 
     unsigned short a_digit = 0; 
     unsigned short b_digit = b; 
     if (i < (int)a.digits.size()) 
      a_digit = a.digits[i]; 
     mult_n_carry += a_digit * b_digit; 
     c.digits[i] = (mult_n_carry & 0xFFFF); 
     mult_n_carry >>= 16; 
    } 
    if (mult_n_carry != 0) { 
     putCarryInfront(c, mult_n_carry); 
    } 
    //std::cout << a.toString() << " x " << b << " == " << c.toString() << std::endl; 
} 

void BigInteger::shiftLeftByBase(BigInteger& b, const BigInteger& a, int times) { 
    b.digits.resize(a.digits.size() + times); 
    for (int i = 0; i < times; ++i) { 
     b.digits[i] = 0; 
    } 
    for (int i = 0; i < (int)a.digits.size(); ++i) { 
     b.digits[i + times] = a.digits[i]; 
    } 
} 

void BigInteger::shiftRight(BigInteger& a) { 
    //std::cout << "shr " << a.toString() << " == "; 
    for (int i = 0; i < (int)a.digits.size(); ++i) { 
     a.digits[i] >>= 1; 
     if (i+1 < (int)a.digits.size()) { 
      if ((a.digits[i+1] & 0x1) != 0) { 
       a.digits[i] |= 0x8000; 
      } 
     } 
    } 
    //std::cout << a.toString() << std::endl; 
} 

void BigInteger::shiftLeft(BigInteger& a) { 
    bool lastBit = false; 
    for (int i = 0; i < (int)a.digits.size(); ++i) { 
     bool bit = (a.digits[i] & 0x8000) != 0; 
     a.digits[i] <<= 1; 
     if (lastBit) 
      a.digits[i] |= 1; 
     lastBit = bit; 
    } 
    if (lastBit) { 
     a.digits.push_back(1); 
    } 
} 

void BigInteger::putCarryInfront(BigInteger& a, unsigned short carry) { 
    BigInteger b; 
    b.negative = a.negative; 
    b.digits.resize(a.digits.size() + 1); 
    b.digits[a.digits.size()] = carry; 
    for (int i = 0; i < (int)a.digits.size(); ++i) { 
     b.digits[i] = a.digits[i]; 
    } 
    a.digits.swap(b.digits); 
} 

void BigInteger::divideWithoutSign(BigInteger& c, BigInteger& d, const BigInteger& a, const BigInteger& b) { 
    c.digits.clear(); 
    c.digits.push_back(0); 
    BigInteger two("2"); 
    BigInteger e = b; 
    BigInteger f("1"); 
    BigInteger g = a; 
    BigInteger one("1"); 
    while (cmpWithoutSign(g, e) >= 0) { 
     shiftLeft(e); 
     shiftLeft(f); 
    } 
    shiftRight(e); 
    shiftRight(f); 
    while (cmpWithoutSign(g, b) >= 0) { 
     g -= e; 
     c += f; 
     while (cmpWithoutSign(g, e) < 0) { 
      shiftRight(e); 
      shiftRight(f); 
     } 
    } 
    e = c; 
    e *= b; 
    f = a; 
    f -= e; 
    d = f; 
} 

BigInteger::BigInteger(const BigInteger& other) { 
    digits = other.digits; 
    negative = other.negative; 
} 

BigInteger::BigInteger(const char* other) { 
    digits.push_back(0); 
    negative = false; 
    BigInteger ten; 
    ten.digits[0] = 10; 
    const char* c = other; 
    bool make_negative = false; 
    if (*c == '-') { 
     make_negative = true; 
     ++c; 
    } 
    while (*c != 0) { 
     BigInteger digit; 
     digit.digits[0] = *c - '0'; 
     *this *= ten; 
     *this += digit; 
     ++c; 
    } 
    negative = make_negative; 
} 

bool BigInteger::isOdd() const { 
    return (digits[0] & 0x1) != 0; 
} 

BigInteger& BigInteger::operator=(const BigInteger& other) { 
    if (this == &other) // handle self assignment 
     return *this; 
    digits = other.digits; 
    negative = other.negative; 
    return *this; 
} 

BigInteger& BigInteger::operator+=(const BigInteger& other) { 
    BigInteger result; 
    if (negative) { 
     if (other.negative) { 
      result.negative = true; 
      addWithoutSign(result, *this, other); 
     } else { 
      int a = cmpWithoutSign(*this, other); 
      if (a < 0) { 
       result.negative = false; 
       subWithoutSign(result, other, *this); 
      } else if (a > 0) { 
       result.negative = true; 
       subWithoutSign(result, *this, other); 
      } else { 
       result.negative = false; 
       result.digits.clear(); 
       result.digits.push_back(0); 
      } 
     } 
    } else { 
     if (other.negative) { 
      int a = cmpWithoutSign(*this, other); 
      if (a < 0) { 
       result.negative = true; 
       subWithoutSign(result, other, *this); 
      } else if (a > 0) { 
       result.negative = false; 
       subWithoutSign(result, *this, other); 
      } else { 
       result.negative = false; 
       result.digits.clear(); 
       result.digits.push_back(0); 
      } 
     } else { 
      result.negative = false; 
      addWithoutSign(result, *this, other); 
     } 
    } 
    negative = result.negative; 
    digits.swap(result.digits); 
    return *this; 
} 

BigInteger& BigInteger::operator-=(const BigInteger& other) { 
    BigInteger neg_other = other; 
    neg_other.negative = !neg_other.negative; 
    return *this += neg_other; 
} 

BigInteger& BigInteger::operator*=(const BigInteger& other) { 
    BigInteger result; 
    for (int i = 0; i < (int)digits.size(); ++i) { 
     BigInteger mult; 
     multByDigitWithoutSign(mult, other, digits[i]); 
     BigInteger shift; 
     shiftLeftByBase(shift, mult, i); 
     BigInteger add; 
     addWithoutSign(add, result, shift); 
     result = add; 
    } 
    if (negative != other.negative) { 
     result.negative = true; 
    } else { 
     result.negative = false; 
    } 
    //std::cout << toString() << " x " << other.toString() << " == " << result.toString() << std::endl; 
    negative = result.negative; 
    digits.swap(result.digits); 
    return *this; 
} 

BigInteger& BigInteger::operator/=(const BigInteger& other) { 
    BigInteger result, tmp; 
    divideWithoutSign(result, tmp, *this, other); 
    result.negative = (negative != other.negative); 
    negative = result.negative; 
    digits.swap(result.digits); 
    return *this; 
} 

BigInteger& BigInteger::operator%=(const BigInteger& other) { 
    BigInteger c, d; 
    divideWithoutSign(c, d, *this, other); 
    *this = d; 
    return *this; 
} 

bool BigInteger::operator>(const BigInteger& other) const { 
    if (negative) { 
     if (other.negative) { 
      return cmpWithoutSign(*this, other) < 0; 
     } else { 
      return false; 
     } 
    } else { 
     if (other.negative) { 
      return true; 
     } else { 
      return cmpWithoutSign(*this, other) > 0; 
     } 
    } 
} 

BigInteger& BigInteger::powAssignUnderMod(const BigInteger& exponent, const BigInteger& modulus) { 
    BigInteger zero("0"); 
    BigInteger one("1"); 
    BigInteger e = exponent; 
    BigInteger base = *this; 
    *this = one; 
    while (cmpWithoutSign(e, zero) != 0) { 
     //std::cout << e.toString() << " : " << toString() << " : " << base.toString() << std::endl; 
     if (e.isOdd()) { 
      *this *= base; 
      *this %= modulus; 
     } 
     shiftRight(e); 
     base *= BigInteger(base); 
     base %= modulus; 
    } 
    return *this; 
} 

std::string BigInteger::toString() const { 
    std::ostringstream os; 
    if (negative) 
     os << "-"; 
    BigInteger tmp = *this; 
    BigInteger zero("0"); 
    BigInteger ten("10"); 
    tmp.negative = false; 
    std::stack<char> s; 
    while (cmpWithoutSign(tmp, zero) != 0) { 
     BigInteger tmp2, tmp3; 
     divideWithoutSign(tmp2, tmp3, tmp, ten); 
     s.push((char)(tmp3.digits[0] + '0')); 
     tmp = tmp2; 
    } 
    while (!s.empty()) { 
     os << s.top(); 
     s.pop(); 
    } 
    /* 
    for (int i = digits.size()-1; i >= 0; --i) { 
     os << digits[i]; 
     if (i != 0) { 
      os << ","; 
     } 
    } 
    */ 
    return os.str(); 

E un esempio di utilizzo.

BigInteger a("87682374682734687"), b("435983748957348957349857345"), c("2348927349872344") 

// Will Calculate pow(87682374682734687, 435983748957348957349857345) % 2348927349872344 
a.powAssignUnderMod(b, c); 

Anche veloce, e ha un numero illimitato di cifre.

+0

grazie per aver condiviso! Una domanda, è una cifra std :: vector ? – darius

+0

Sì, ma lavorando in base 65536 sotto il cofano, non base 10. – clinux

1
package playTime; 

    public class play { 

     public static long count = 0; 
     public static long binSlots = 10; 
     public static long y = 645; 
     public static long finalValue = 1; 
     public static long x = 11; 

     public static void main(String[] args){ 

      int[] binArray = new int[]{0,0,1,0,0,0,0,1,0,1}; 

      x = BME(x, count, binArray); 

      System.out.print("\nfinal value:"+finalValue); 

     } 

     public static long BME(long x, long count, int[] binArray){ 

      if(count == binSlots){ 
       return finalValue; 
      } 

      if(binArray[(int) count] == 1){ 
       finalValue = finalValue*x%y; 
      } 

      x = (x*x)%y; 
      System.out.print("Array("+binArray[(int) count]+") " 
          +"x("+x+")" +" finalVal("+    finalValue + ")\n"); 

      count++; 


      return BME(x, count,binArray); 
     } 

    } 
+0

quello era il codice che ho scritto in java molto rapidamente. L'esempio che ho usato era 11^644mod 645. = 1. sappiamo che il binario di 645 è 1010000100. Ho una specie di codice e le variabili sono state codificate ma funzionano bene. – ShowLove

+0

output era Array (0) x (121) finalVal (1) Array (0) x (451) finalVal (1) Array (1) x (226) finalVal (451) Array (0) x (121) finalVal (451) Array (0) x (451) finalVal (451) Array (0) x (226) finalVal (451) Array (0) x (121) finalVal (451) Array (1) x (451) finalVal (391) Array (0) x (226) finalVal (391) Array (1) x (121) finalVal (1) valore finale: 1 – ShowLove

0

LL è per long long int

LL power_mod(LL a, LL k) { 
    if (k == 0) 
     return 1; 
    LL temp = power(a, k/2); 
    LL res; 

    res = ((temp % P) * (temp % P)) % P; 
    if (k % 2 == 1) 
     res = ((a % P) * (res % P)) % P; 
    return res; 
} 

utilizzare la funzione ricorsiva sopra per trovare l'exp mod del numero. Ciò non comporterà un overflow perché calcola in modo bottom-up.

Esempio di prova Per: a = 2 e k = 168277 mostra uscita sia 518.358 che è corretto e la funzione viene eseguita in O(log(k)) tempo;