2015-05-31 25 views
5

occorre inizializzare un tensore 3D con una funzione di indice-dipendente torch7, cioèmodo rapido per inizializzare un tensore in torch7

func = function(i,j,k) --i, j is the index of an element in the tensor 
    return i*j*k  --do operations within func which're dependent of i, j 
end 

poi inizializza un tensore 3D Un simili:

for i=1,A:size(1) do 
    for j=1,A:size(2) do 
     for k=1,A:size(3) do 
      A[{i,j,k}] = func(i,j,k) 
     end 
    end 
end 

Ma questo codice funziona molto lentamente e ho scoperto che occupa il 92% del tempo totale di esecuzione. Esistono modi più efficaci per inizializzare un tensore 3D nella torcia 7?

+0

Qual è la dimensione di 'A'? – ryanpattison

risposta

7

Vedere la documentazione per i Tensor:apply

Queste funzioni applicare una funzione a ciascun elemento del tensore cui il metodo viene chiamato (self). Questi metodi sono molto più veloci di utilizzando un ciclo for in Lua.

L'esempio nei documenti inizializza un array 2D basato sul suo indice i (in memoria). Di seguito è riportato un esempio esteso di 3 dimensioni e inferiore a quello dei tensori N-D. Utilizzando il metodo applicare è molto, molto più veloce sulla mia macchina:

require 'torch' 

A = torch.Tensor(100, 100, 1000) 
B = torch.Tensor(100, 100, 1000) 

function func(i,j,k) 
    return i*j*k  
end 

t = os.clock() 
for i=1,A:size(1) do 
    for j=1,A:size(2) do 
     for k=1,A:size(3) do 
      A[{i, j, k}] = i * j * k 
     end 
    end 
end 
print("Original time:", os.difftime(os.clock(), t)) 

t = os.clock() 
function forindices(A, func) 
    local i = 1 
    local j = 1 
    local k = 0 
    local d3 = A:size(3) 
    local d2 = A:size(2) 
    return function() 
    k = k + 1 
    if k > d3 then 
     k = 1 
     j = j + 1 
     if j > d2 then 
     j = 1 
     i = i + 1 
     end 
    end 
    return func(i, j, k) 
    end 
end 

B:apply(forindices(A, func)) 
print("Apply method:", os.difftime(os.clock(), t)) 

EDIT

Questo funziona per qualsiasi oggetto Tensor:

function tabulate(A, f) 
    local idx = {} 
    local ndims = A:dim() 
    local dim = A:size() 
    idx[ndims] = 0 
    for i=1, (ndims - 1) do 
    idx[i] = 1 
    end 
    return A:apply(function() 
    for i=ndims, 0, -1 do 
     idx[i] = idx[i] + 1 
     if idx[i] <= dim[i] then 
     break 
     end 
     idx[i] = 1 
    end 
    return f(unpack(idx)) 
    end) 
end 

-- usage for 3D case. 
tabulate(A, function(i, j, k) return i * j * k end) 
+0

@deltheil sì, grazie. – ryanpattison

+0

prego! (commento rimosso poiché non è più rilevante dopo questo [modifica] (http://stackoverflow.com/revisions/30560653/5)) – deltheil

+0

ottima risposta! Finché il functor può essere correttamente compilato con JIT, sarà molto veloce (vicino alla velocità C) – smhx

Problemi correlati