2013-03-09 11 views
11

Ho scritto la funzione int compare_16bytes(__m128i lhs, __m128i rhs) per confrontare due numeri a 16 byte utilizzando le istruzioni SSE: questa funzione restituisce quanti byte sono uguali dopo aver eseguito il confronto.Conteggio veloce del numero di byte uguali tra due array

Ora vorrei utilizzare la funzione di cui sopra per confrontare due array di byte di lunghezza arbitraria: la lunghezza non può essere un multiplo di 16 byte, quindi ho bisogno di affrontare questo problema. Come posso completare l'implementazione della funzione qui sotto? Come posso migliorare la funzione qui sotto?

int fast_compare(const char* s, const char* t, int length) 
{ 
    int result = 0; 

    const char* sPtr = s; 
    const char* tPtr = t; 

    while(...) 
    { 
     const __m128i* lhs = (const __m128i*)sPtr; 
     const __m128i* rhs = (const __m128i*)tPtr; 

     // compare the next 16 bytes of s and t 
     result += compare_16bytes(*lhs,*rhs); 

     sPtr += 16; 
     tPtr += 16; 
    } 

    return result; 
} 
+2

Utilizzare un ciclo for (lunghezza/16 volte) e azzerare gli zeri su lhs e quelli su rh se i byte rimanenti sono inferiori a 16. Il padding dovrebbe essere diverso in modo che non contenga il padding in modo errato. –

+1

'while (length> = 16) {/ * usa la tua funzione */length - = 16; } if (length)/* usa una versione che confronta lunghezza (fino a 15) byte * /; ' – pmg

+1

FYI viene spesso chiamata [* Distanza di Hamming *] (http://en.wikipedia.org/wiki/Hamming_distance) - questo può essere utile come termine di ricerca. –

risposta

6

Come @Mysticial dice nei commenti di cui sopra, fare il confronto e sintesi in senso verticale e poi basta sommare orizzontalmente alla fine del ciclo principale:

#include <stdio.h> 
#include <stdlib.h> 
#include <time.h> 
#include <emmintrin.h> 

// reference implementation 
int fast_compare_ref(const char *s, const char *t, int length) 
{ 
    int result = 0; 
    int i; 

    for (i = 0; i < length; ++i) 
    { 
     if (s[i] == t[i]) 
      result++; 
    } 
    return result; 
} 

// optimised implementation 
int fast_compare(const char *s, const char *t, int length) 
{ 
    int result = 0; 
    int i; 

    __m128i vsum = _mm_set1_epi32(0); 
    for (i = 0; i < length - 15; i += 16) 
    { 
     __m128i vs, vt, v, vh, vl, vtemp; 

     vs = _mm_loadu_si128((__m128i *)&s[i]); // load 16 chars from input 
     vt = _mm_loadu_si128((__m128i *)&t[i]); 
     v = _mm_cmpeq_epi8(vs, vt);    // compare 
     vh = _mm_unpackhi_epi8(v, v);   // unpack compare result into 2 x 8 x 16 bit vectors 
     vl = _mm_unpacklo_epi8(v, v); 
     vtemp = _mm_madd_epi16(vh, vh);   // accumulate 16 bit vectors into 4 x 32 bit partial sums 
     vsum = _mm_add_epi32(vsum, vtemp); 
     vtemp = _mm_madd_epi16(vl, vl); 
     vsum = _mm_add_epi32(vsum, vtemp); 
    } 

    // get sum of 4 x 32 bit partial sums 
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 8)); 
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 4)); 
    result = _mm_cvtsi128_si32(vsum); 

    // handle any residual bytes (< 16) 
    if (i < length) 
    { 
     result += fast_compare_ref(&s[i], &t[i], length - i); 
    } 

    return result; 
} 

// test harness 
int main(void) 
{ 
    const int n = 1000000; 
    char *s = malloc(n); 
    char *t = malloc(n); 
    int i, result_ref, result; 

    srand(time(NULL)); 

    for (i = 0; i < n; ++i) 
    { 
     s[i] = rand(); 
     t[i] = rand(); 
    } 

    result_ref = fast_compare_ref(s, t, n); 
    result = fast_compare(s, t, n); 

    printf("result_ref = %d, result = %d\n", result_ref, result);; 

    return 0; 
} 

Compilare ed eseguire il test harness sopra:

$ gcc -Wall -O3 -msse3 fast_compare.c -o fast_compare 
$ ./fast_compare 
result_ref = 3955, result = 3955 
$ ./fast_compare 
result_ref = 3947, result = 3947 
$ ./fast_compare 
result_ref = 3945, result = 3945 

Nota che c'è un trucco forse non evidente nel codice SSE sopra dove usiamo _mm_madd_epi16 per spacchettare e accumulare 16 bit 0/-1 valori a somme parziali a 32 bit. Approfittiamo del fatto che lo -1*-1 = 1 (e lo 0*0 = 0 ovviamente) - qui non stiamo davvero facendo un multiplo, semplicemente disimballando e sommando in un'unica istruzione.


UPDATE: come notato nei commenti qui sotto, questa soluzione non è ottimale - Ho appena preso una soluzione abbastanza ottimale a 16 bit e ha aggiunto 8 bit a 16 bit disimballaggio per farlo funzionare per 8 bit di dati. Tuttavia per i dati a 8 bit esistono metodi più efficienti, ad es. utilizzando psadbw/_mm_sad_epu8. Lascerò qui questa risposta per i posteri e per chiunque voglia fare questo genere di cose con dati a 16 bit, ma in realtà una delle altre risposte che non richiede di decomprimere i dati di input dovrebbe essere la risposta accettata.

+0

Great! Funziona correttamente! Inoltre, è importante che i due vettori 's' e' t' siano _aligned_? Qual è l'allineamento? – enzom83

+1

Nell'esempio precedente ho usato '_mm_loadu_si128' in modo che non abbia importanza sull'allineamento. Se è possibile garantire che 's' e' t' siano allineati a 16 byte, allora usare '_mm_load_si128' invece di' _mm_loadu_si128' per prestazioni migliori, in particolare su CPU meno recenti. –

+0

_mm_setzero_si128() potrebbe essere più veloce di _mm_set1_epi32 (0) per l'azzeramento di vsum. – leecbaker

1

Il confronto in intero in SSE produce byte che sono tutti zeri o tutti. Se vuoi contare, devi prima spostare a destra (non aritmetico) il risultato del confronto per 7, quindi aggiungere al vettore risultato. Alla fine, è ancora necessario ridurre il vettore risultato sommando i suoi elementi. Questa riduzione deve essere eseguita in codice scalare o con una sequenza di add/shift. Di solito questa parte non vale la pena preoccuparsi.

3

L'utilizzo di somme parziali in 16 elementi uint8 può fornire prestazioni ancora migliori.
Ho diviso il ciclo in anello interno e anello esterno.
Il loop interno somma di uint8 elementi (ogni elemento uint8 può riassumere fino a 255 "1" s).
Piccolo trucco: _mm_cmpeq_epi8 imposta gli elementi uguali su 0xFF e (char) 0xFF = -1, quindi puoi sottrarre il risultato dalla somma (sottrarre -1 per aggiungere 1).

Ecco mia versione ottimizzata per fast_compare:

int fast_compare2(const char *s, const char *t, int length) 
{ 
    int result = 0; 
    int inner_length = length; 
    int i; 
    int j = 0; 

    //Points beginning of 4080 elements block. 
    const char *s0 = s; 
    const char *t0 = t; 


    __m128i vsum = _mm_setzero_si128(); 

    //Outer loop sum result of 4080 sums. 
    for (i = 0; i < length; i += 4080) 
    { 
     __m128i vsum_uint8 = _mm_setzero_si128(); //16 uint8 sum elements (each uint8 element can sum up to 255). 
     __m128i vh, vl, vhl, vhl_lo, vhl_hi; 

     //Points beginning of 4080 elements block. 
     s0 = s + i; 
     t0 = t + i; 

     if (i + 4080 <= length) 
     { 
      inner_length = 4080; 
     } 
     else 
     { 
      inner_length = length - i; 
     } 

     //Inner loop - sum up to 4080 (compared) results. 
     //Each uint8 element can sum up to 255. 16 uint8 elements can sum up to 255*16 = 4080 (compared) results. 
     ////////////////////////////////////////////////////////////////////////// 
     for (j = 0; j < inner_length-15; j += 16) 
     { 
       __m128i vs, vt, v; 

       vs = _mm_loadu_si128((__m128i *)&s0[j]); // load 16 chars from input 
       vt = _mm_loadu_si128((__m128i *)&t0[j]); 
       v = _mm_cmpeq_epi8(vs, vt);    // compare - set to 0xFF where equal, and 0 otherwise. 

       //Consider this: (char)0xFF = (-1) 
       vsum_uint8 = _mm_sub_epi8(vsum_uint8, v); //Subtract the comparison result - subtract (-1) where equal. 
     } 
     ////////////////////////////////////////////////////////////////////////// 

     vh = _mm_unpackhi_epi8(vsum_uint8, _mm_setzero_si128());  // unpack result into 2 x 8 x 16 bit vectors 
     vl = _mm_unpacklo_epi8(vsum_uint8, _mm_setzero_si128()); 
     vhl = _mm_add_epi16(vh, vl); //Sum high and low as uint16 elements. 

     vhl_hi = _mm_unpackhi_epi16(vhl, _mm_setzero_si128()); //unpack sum of vh an vl into 2 x 4 x 32 bit vectors 
     vhl_lo = _mm_unpacklo_epi16(vhl, _mm_setzero_si128()); //unpack sum of vh an vl into 2 x 4 x 32 bit vectors 

     vsum = _mm_add_epi32(vsum, vhl_hi); 
     vsum = _mm_add_epi32(vsum, vhl_lo); 
    } 

    // get sum of 4 x 32 bit partial sums 
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 8)); 
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 4)); 
    result = _mm_cvtsi128_si32(vsum); 

    // handle any residual bytes (< 16) 
    if (j < inner_length) 
    { 
     result += fast_compare_ref(&s0[j], &t0[j], inner_length - j); 
    } 

    return result; 
} 
+0

Heh, avrei dovuto guardare la nuova risposta prima di commentare su Paul; Ho suggerito la stessa cosa ('psubb' all'interno di un loop interno). Questo è ciò che intendevo, eccetto che dovresti usare 'psadbw' per fare la somma orizzontale di' vsum_uint8' (vedi i miei commenti sulla risposta di Paul). –

+0

Ho pensato di usare la somma orizzontale, ma ho deciso di mantenere la compatibilità SSE2. – Rotem

+0

Stai parlando di 'phaddd'? Non è quello che ho detto. 'phaddd' [il vantaggio è solo la dimensione del codice] (http: // stackoverflow.it/questions/6996764/quick-way-to-do-orizzontale-float-vector-sum-on-x86/35270026 # 35270026) sulle attuali CPU. Vedi anche la mia risposta a questa domanda, che usa solo le istruzioni SSE2. –

2

Il modo più veloce per grandi ingressi è la risposta di Rotem, dove il ciclo interno è pcmpeqb/psubb, all'esterno di somma orizzontalmente prima di qualsiasi elemento byte del vettore accumulatore trabocca. Esegui il byte di byte senza segno con psadbw contro un vettore tutto-zero.

Senza di svolgimento/cicli annidati, l'opzione migliore è probabilmente

pcmpeqb -> vector of 0 or 0xFF elements 
psadbw -> two 64bit sums of (0*no_matches + 0xFF*matches) 
paddq  -> accumulate the psadbw result in a vector accumulator 

#outside the loop: 
horizontal sum 
divide the result by 255 

Se non si dispone di un sacco di pressione registrati in loop, psadbw contro un vettore di 0x7f invece di tutti da zero.

  • psadbw(0x00, set1(0x7f)) =>sum += 0x7f
  • psadbw(0xff, set1(0x7f)) =>sum += 0x80

Così, invece di dividendo per 255 (che il compilatore dovrebbe fare in modo efficiente, senza un vero e proprio div), basta sottrarre n * 0x7f, dove n è il numero di elementi.

Si noti inoltre che paddq è lento sulla pre-Nehalem e Atom, così si potrebbe usare paddd (_mm_add_epi32) se non vi aspettate 128 * il conteggio overflow mai un intero a 32 bit.

Ciò confronta molto bene con del Paul R pcmpeqb/2x punpck/2x pmaddwd/2x paddw.

Problemi correlati