2016-06-04 12 views
7

Ho una matrice numpy di dimensioni (4, X, Y), in cui la prima dimensione rappresenta un quadruplo (R, G, B, A). Il mio obiettivo è quello di trasporre ogni quadruplet X*Y RGBA a valori a virgola mobile X*Y, dato un dizionario che li abbina.Miglioramento delle prestazioni dell'operazione di mappatura numpy

mio codice attuale è la seguente:

codeTable = { 
    (255, 255, 255, 127): 5.5, 
    (128, 128, 128, 255): 6.5, 
    (0 , 0 , 0 , 0 ): 7.5, 
} 

for i in range(0, rows): 
    for j in range(0, cols): 
     new_data[i,j] = codeTable.get(tuple(data[:,i,j]), -9999) 

Dove data è una matrice di dimensione NumPy (4, rows, cols), e new_data è di dimensioni (rows, cols).

Il codice funziona correttamente, ma richiede molto tempo. Come dovrei ottimizzare quel pezzo di codice?

Ecco un esempio completo:

import numpy 

codeTable = { 
    (253, 254, 255, 127): 5.5, 
    (128, 129, 130, 255): 6.5, 
    (0 , 0 , 0 , 0 ): 7.5, 
} 

# test data 
rows = 2 
cols = 2 
data = numpy.array([ 
    [[253, 0], [128, 0], [128, 0]], 
    [[254, 0], [129, 144], [129, 0]], 
    [[255, 0], [130, 243], [130, 5]], 
    [[127, 0], [255, 120], [255, 5]], 
]) 

new_data = numpy.zeros((rows,cols), numpy.float32) 

for i in range(0, rows): 
    for j in range(0, cols): 
     new_data[i,j] = codeTable.get(tuple(data[:,i,j]), -9999) 

# expected result for `new_data`: 
# array([[ 5.50000000e+00, 7.50000000e+00], 
#  [ 6.50000000e+00, -9.99900000e+03], 
#  [ 6.50000000e+00, -9.99900000e+03], dtype=float32) 
+0

Come ci sono molte 'righe' e' cols'? – Will

+0

@Will Molte migliaia per ciascuno. –

+0

Forse questo aiuterà: http://stackoverflow.com/questions/36480358/whats-a-fast-non-loop-way-to-apply-a-dict-to-a-array-mating-use-elements – hpaulj

risposta

1

Ecco un approccio che restituisce il risultato previsto, ma con una piccola quantità di dati, è difficile sapere se questo sarà più veloce per voi. Dal momento che ho evitato il doppio ciclo, tuttavia, immagino che vedrai un discreto aumento di velocità.

import numpy 
import pandas as pd 


codeTable = { 
    (253, 254, 255, 127): 5.5, 
    (128, 129, 130, 255): 6.5, 
    (0 , 0 , 0 , 0 ): 7.5, 
} 

# test data 
rows = 3 
cols = 2 
data = numpy.array([ 
    [[253, 0], [128, 0], [128, 0]], 
    [[254, 0], [129, 144], [129, 0]], 
    [[255, 0], [130, 243], [130, 5]], 
    [[127, 0], [255, 120], [255, 5]], 
]) 

new_data = numpy.zeros((rows,cols), numpy.float32) 

for i in range(0, rows): 
    for j in range(0, cols): 
     new_data[i,j] = codeTable.get(tuple(data[:,i,j]), -9999) 

def create_output(data): 
    # Reshape your two data sources to be a bit more sane 
    reshaped_data = data.reshape((4, -1)) 
    df = pd.DataFrame(reshaped_data).T 

    reshaped_codeTable = [] 
    for key in codeTable.keys(): 
     reshaped = list(key) + [codeTable[key]] 
     reshaped_codeTable.append(reshaped) 
    ct = pd.DataFrame(reshaped_codeTable) 

    # Merge on the data, replace missing merges with -9999 
    result = df.merge(ct, how='left') 
    newest_data = result[4].fillna(-9999) 

    # Reshape 
    output = newest_data.reshape(rows, cols) 
    return output 

output = create_output(data) 
print(output) 
# array([[ 5.50000000e+00, 7.50000000e+00], 
#  [ 6.50000000e+00, -9.99900000e+03], 
#  [ 6.50000000e+00, -9.99900000e+03]) 

print(numpy.array_equal(new_data, output)) 
# True 
+0

La tua soluzione sembra funzionare solo per dati di input quadrati e non funziona quando 'cols! = Rows'. Ma grazie per le idee, indagherò. Ad ogni modo, la velocità è molto più soddisfacente della mia ingenua soluzione a doppio anello. –

+0

corretto! Questo ora prenderà il numero richiesto di righe e colonne. –

+0

Bene, il tuo codice non funziona con altre forme di dati. Ho aggiornato il mio messaggio iniziale con un esempio più complicato. Il codice restituisce risultati corretti, ma nella posizione errata nell'array di output. –

1

Il pacchetto numpy_indexed (disclaimer: io sono il suo autore) contiene un vectorized ND-array variante capace di indice di una lista, che può essere utilizzato per risolvere il problema in modo efficace e conciso:

import numpy_indexed as npi 
map_keys = np.array(list(codeTable.keys())) 
map_values = np.array(list(codeTable.values())) 
indices = npi.indices(map_keys, data.reshape(4, -1).T, missing='mask') 
remapped = np.where(indices.mask, -9999, map_values[indices.data]).reshape(data.shape[1:]) 
+0

La tua soluzione sembra funzionare come un incantesimo. Grazie! Discuterò in seguito sui miglioramenti delle prestazioni. –

+0

In attesa di confronto delle prestazioni! –

Problemi correlati