Si noti che il calcolo di tale dovrà fare almeno ~ n × N = 173 miliardi operazioni (non considerando la simmetria), quindi sarà lenta a meno che NumPy ha accesso alla GPU o qualcosa del genere. Su un computer moderno con una CPU da ~ 3 GHz, l'intero calcolo dovrebbe richiedere circa 60 secondi, supponendo che SIMD/parallelo non accelerino.
Per il test, cominciamo con N = 1000. Useremo questo per controllare la correttezza e le prestazioni:
#!/usr/bin/env python3
import numpy
import time
numpy.random.seed(0)
n = 120
N = 1000
X = numpy.random.random((N, n))
start_time = time.time()
M3 = numpy.einsum('ij,ik,il->jkl', X, X, X)
end_time = time.time()
print('check:', M3[2,4,6], '= 125.401852515?')
print('check:', M3[4,2,6], '= 125.401852515?')
print('check:', M3[6,4,2], '= 125.401852515?')
print('check:', numpy.sum(M3), '= 218028826.631?')
print('total time =', end_time - start_time)
Questa operazione richiede circa 8 secondi. Questa è la linea di base.
Cominciamo con il ciclo nidificato 3 come l'alternativa:
M3 = numpy.zeros((n, n, n))
for j in range(n):
for k in range(n):
for l in range(n):
M3[j,k,l] = numpy.sum(X[:,j] * X[:,k] * X[:,l])
# ~27 seconds
Questo richiede circa mezzo minuto, non va bene! Uno dei motivi è dato dal fatto che si tratta in realtà di quattro cicli annidati: numpy.sum
può anche essere considerato un ciclo.
Notiamo che la somma può essere trasformato in un prodotto scalare per rimuovere questo 4 ° ciclo:
M3 = numpy.zeros((n, n, n))
for j in range(n):
for k in range(n):
for l in range(n):
M3[j,k,l] = X[:,j] * X[:,k] @ X[:,l]
# 14 seconds
molto meglio ora, ma ancora lento. Ma notiamo che il prodotto il punto può essere cambiata in una moltiplicazione di matrici per rimuovere un loop:
M3 = numpy.zeros((n, n, n))
for j in range(n):
for k in range(n):
M3[j,k] = X[:,j] * X[:,k] @ X
# ~0.5 seconds
Eh? Ora questo è persino molto più efficiente di einsum
! Potremmo anche verificare che la risposta sia effettivamente corretta.
Possiamo andare oltre? Sì! Potremmo eliminare l'anello k
da:
M3 = numpy.zeros((n, n, n))
for j in range(n):
Y = numpy.repeat(X[:,j], n).reshape((N, n))
M3[j] = (Y * X).T @ X
# ~0.3 seconds
ci potrebbe anche usare radiodiffusione (cioè a * [b,c] == [a*b, a*c]
per ogni fila di X) per evitare di fare il numpy.repeat
(grazie @Divakar):
M3 = numpy.zeros((n, n, n))
for j in range(n):
Y = X[:,j].reshape((N, 1))
## or, equivalently:
# Y = X[:, numpy.newaxis, j]
M3[j] = (Y * X).T @ X
# ~0.16 seconds
Se si scala questo a N = 100000 ci si aspetta che il programma impieghi 16 secondi, che è entro il limite teorico, quindi eliminare lo j
potrebbe non essere di grande aiuto (ma questo potrebbe rendere il codice davvero difficile da capire). Potremmo accettare come soluzione finale.
Nota: Se si sta utilizzando Python 2, a @ b
è equivalente a a.dot(b)
.
ottima risposta, grazie! –
Ottima idea davvero. Se posso aggiungere un po 'di trasmissione qui, potremmo evitare di creare 'Y' e ottenere direttamente l'output iterativo:' (X [:, None, j] * X) .T @ X'. Questo dovrebbe darci un ulteriore incremento delle prestazioni. – Divakar
@Divakar: Grazie! Aggiornato. – kennytm