Getting to know einsum


    Concept

     

    Every now and then I see posts promoting einsum, or run across it in code.

    At first I didn't think much of it, but looking more closely there seem to be a lot of appealing aspects, so I decided to look into it.

    What's especially appealing is that once you know it, you can apply it identically in numpy, pytorch, and tensorflow.

    Linear algebra plays a fundamental role in deep learning, and we are still in a warring-states era: there is no single unified library and new ones keep appearing, so I thought it would be worth learning something that ties them all together.

    And if you use it well, even complex operations seem easy to implement.

    The really appealing part is that with the einsum operation you can write matrix inner products, outer products, matrix multiplication, and more, all in the same form.

     

    Left: before applying einsum, right: after applying einsum

     

    With Einstein notation and the einsum function, all of these can be written in the same format.

    np.einsum(equation, *operands)
    torch.einsum(equation, *operands)
    tensorflow.einsum(equation, *operands)
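
    As a quick sanity check, here is a minimal sketch (my own example, assuming torch and tensorflow are installed) showing the same equation string producing the same result in all three libraries:

    import numpy as np
    import torch
    import tensorflow as tf

    A = np.arange(6, dtype=np.float32).reshape(2, 3)
    B = np.arange(12, dtype=np.float32).reshape(3, 4)

    # the identical equation string "ij,jk->ik" works across the three libraries
    out_np = np.einsum("ij,jk->ik", A, B)
    out_torch = torch.einsum("ij,jk->ik", torch.from_numpy(A), torch.from_numpy(B))
    out_tf = tf.einsum("ij,jk->ik", tf.constant(A), tf.constant(B))

    print(np.allclose(out_np, out_torch.numpy()))  # True
    print(np.allclose(out_np, out_tf.numpy()))     # True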

     

    • equation
      • An expression made up of the lowercase letters that correspond to each index of the operands.
      • Left of the `->`
        • the part that labels the dimensions of the operands
        • operands are separated by `,`
      • Right of the `->`
        • the index order of the output dimensions
        • if it is omitted, the output is apparently defined internally as the letters that appear only once, listed in alphabetical order (see the small check right after this list).
      • Example
        • ni, ij -> nj
          • i is the summed index
          • $\sum_i A_{ni} B_{ij}$
    • operands (tensors)
      • the objects the operation is performed on
      • one, two, three or more are possible.
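
    A small check of the implicit-output behaviour mentioned above (my own example): when the part after `->` is omitted, the indices that appear only once are kept, in alphabetical order.

    import numpy as np

    A = np.arange(6).reshape(2, 3)
    B = np.arange(12).reshape(3, 4)

    # "i" and "k" each appear once, so "ij,jk" behaves like "ij,jk->ik" (matrix multiplication)
    print(np.allclose(np.einsum("ij,jk", A, B), A @ B))  # True

    # "i" is repeated and no index appears once, so "ii" sums the diagonal, i.e. the trace
    C = np.arange(9).reshape(3, 3)
    print(np.einsum("ii", C) == np.trace(C))  # True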

    I think you can picture it running like the nested loops sketched below!
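
    Written out as explicit loops (my own illustration, not how numpy actually computes it), "ij,jk->ik" keeps the free indices i and k in the output and sums over the repeated index j:

    import numpy as np

    def naive_einsum_ij_jk_ik(A, B):
        """Loop version of np.einsum("ij,jk->ik", A, B), for intuition only."""
        I, J = A.shape
        _, K = B.shape
        out = np.zeros((I, K))
        for i in range(I):              # free index i -> kept in the output
            for k in range(K):          # free index k -> kept in the output
                for j in range(J):      # repeated index j -> summed away
                    out[i, k] += A[i, j] * B[j, k]
        return out

    A = np.random.uniform(size=(2, 3))
    B = np.random.uniform(size=(3, 4))
    print(np.allclose(naive_einsum_ij_jk_ik(A, B), np.einsum("ij,jk->ik", A, B)))  # True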

     

     

    Examples

    import numpy as np

    ## check function: compare an einsum result against the equivalent numpy function
    def simple_check_eisum(equation: str, operands, np_function_result):
        print(equation)
        print(operands)
        assert np.allclose(np.einsum(equation, *operands), np_function_result), "Numpy result is different from np function result"
        print(np.einsum(equation, *operands))

    As much as possible, I want to compare the plain numpy functions with their einsum equivalents.

    Transpose

    $A_{i, j}=B_{j, i}$

     

    mat2d = np.random.uniform(size=[2,2])
    print(mat2d)
    np.allclose(np.einsum("ij->ji",mat2d) , np.transpose(mat2d,(1,0)))

     

    $A_{i, j, k}=B_{k, j, i}$

    mat3d = np.random.uniform(size=[2,2,2])
    print(mat3d)
    np.einsum("ijk->kji",mat3d)
    
    np.allclose(np.einsum("ijk->kji",mat3d) , np.transpose(mat3d,(2,1,0)))
    np.allclose(np.einsum("ijk->jik",mat3d) , np.transpose(mat3d,(1,0,2)))

     

    Trace

    $\sum_i C_{i,i}$

     

    mat2d = np.random.uniform(size=[2,2])
    np.allclose(np.einsum("ii",mat2d) , np.trace(mat2d))

    Summation

    $b=\sum_i \sum_j A_{i, j} = A_{i, j}$ (ij->)

     

    $b_i=\sum_j A_{i,j} = A_{i,j}$ (ij->i)

     

    $b_j=\sum_i A_{i,j} = A_{i,j}$ (ij->j)

     

    mat2d = np.random.uniform(size=[2,2])
    simple_check_eisum("ij->",[mat2d],np.sum(mat2d))
    simple_check_eisum("ij->i",[mat2d],np.sum(mat2d,axis=1))
    simple_check_eisum("ij->j",[mat2d],np.sum(mat2d,axis=0))

     

    MATRIX VECTOR MULTIPLICATION

    $c_i = \sum_j A_{i, j} B_j = A_{ij} B_j$ (ij, j->i)

     

    test_matrix = np.arange(6).reshape([2,3])
    test_vector = np.arange(3)
    np.einsum("ij,j->i",*[test_matrix, test_vector])
    simple_check_eisum("ij,j->i",[test_matrix, test_vector],np.dot(test_matrix,test_vector))
    test_matrix_1 = np.arange(6).reshape([2,3])
    test_vector_1 = np.arange(3)
    simple_check_eisum("ij,jk->ik",[test_matrix_1, test_vector_1[:,np.newaxis]],np.matmul(test_matrix_1, test_vector_1[:,np.newaxis]))

    MATRIX MATRIX MULTIPLICATION

    $c_{ij} = \sum_k A_{i, k} B_{k, j} = A_{ik} B_{kj}$ (ik, kj->ij)

    test_matrix = np.arange(6).reshape([2,3])
    test_vector = np.arange(3)
    simple_check_eisum("ik,kj->ij",[test_matrix, test_matrix.T],np.dot(test_matrix , test_matrix.T))

    $c_{ij} = \sum_k A_{ik} B_{kj}$ (ik, kj->ij)

    test_matrix_1 = np.arange(6).reshape([2,3])
    test_matrix_2 = np.arange(12).reshape([3,4])
    simple_check_eisum("ij,kj->ik",[test_matrix_1, test_matrix_2],np.matmul(test_matrix_1 , test_matrix_2 ))

    $c_{ik} = \sum_j A_{ij} B_{kj} = \sum_j A_{ij} (B^{T})_{jk}$ (ij, kj->ik)

    test_matrix_1 = np.arange(6).reshape([2,3])
    test_matrix_2 = np.arange(12).reshape([4,3])
    simple_check_eisum("ij,kj->ik",[test_matrix_1, test_matrix_2],np.matmul(test_matrix_1 ,  np.transpose(test_matrix_2,(1,0))))

    DOT PRODUCT 

    (Vector)

    $c = \sum_i A_{i} B_{i} = A_{i} B_{i}$ (i, i->)

    test_vector_1 = np.arange(1,4)
    test_vector_2 = np.arange(3,6)
    simple_check_eisum("i,i->",[test_vector_1, test_vector_2],np.sum(test_vector_1 * test_vector_2))

    (Matrix)

    $c = \sum_i \sum_j A_{ij} B_{ij} = A_{ij} B_{ij}$ (ij, ij->)

    test_mat_1 = np.arange(6).reshape(2, 3)
    test_mat_2 = np.arange(6,12).reshape(2, 3)
    simple_check_eisum("ij,ij->",[test_mat_1, test_mat_2],np.sum(test_mat_1 * test_mat_2))

    OUTER PRODUCT 

    $c_{i, j} = a_i b_j$ (i, j->ij)

    simple_check_eisum("i,j->ij",[test_vector_1, test_vector_2],test_vector_1[:,np.newaxis] * test_vector_2[np.newaxis,:])

    $c_{ji} = (a_i b_j)^{T}$ (i, j->ji)

     

    simple_check_eisum("i,j->ji",[test_vector_1, test_vector_2],test_vector_2[:,np.newaxis] * test_vector_1[np.newaxis,:])

     

    HADAMARD PRODUCT 

    $c_{ij} = A_{ij} B_{ij}$ (ij, ij->ij)

    test_matrix_1 = np.arange(6).reshape([2,3])
    test_matrix_2 = np.arange(6).reshape([2,3])
    simple_check_eisum("ij,ij->ij",[test_matrix_1, test_matrix_2],test_matrix_1*test_matrix_2)

    $c_{ji} = (A_{ij} B_{ij})^{T}$ (ij, ij->ji)

    simple_check_eisum("ij,ij->ji",[test_matrix_1, test_matrix_2],(test_matrix_1*test_matrix_2).T)

     


    Batch Matrix Multiplication

    $c_{ijl} = \sum_k A_{ijk} B_{ikl} = A_{ijk} B_{ikl}$ (ijk, ikl->ijl)

     
    import torch
    i, j, k, l = 2, 1, 2, 3
    test_matrix_1 = np.random.uniform(size=(i,j,k))
    test_matrix_2 = np.random.uniform(size=(i,k,l))
    print(test_matrix_1.shape , test_matrix_2.shape)
    simple_check_eisum("ijk,ikl->ijl",[test_matrix_1, test_matrix_2],np.matmul(test_matrix_1 , test_matrix_2))
    simple_check_eisum("ijk,ikl->ijl",[test_matrix_1, test_matrix_2],torch.bmm(torch.tensor(test_matrix_1), torch.tensor(test_matrix_2)).numpy())

     

    test_matrix_3d = np.ones((3, 3, 3))
    test_matrix_2d = np.random.randint(0, 10, (3, 3))
    
    simple_check_eisum("BNi,Bi->BN",[test_matrix_3d, test_matrix_2d],np.matmul(test_matrix_3d, test_matrix_2d[:, :, None]).squeeze(-1))
    simple_check_eisum("BNi,Bi->BN",[test_matrix_3d, test_matrix_2d],(test_matrix_3d @ test_matrix_2d[:, :, None]).squeeze(-1))

     

    test_matrix_3d = np.ones((3, 3, 3))
    test_matrix_2d = np.random.randint(0, 10, (3, 1))
    simple_check_eisum("Bkj,Bl->Bjl",[test_matrix_3d, test_matrix_2d],np.matmul(test_matrix_3d,np.tile(test_matrix_2d,3)[:,:,None]))

     

    Bilinear Transformation

    i, j, k, l = 2, 3, 2, 2
    test_matrix_1 = np.random.uniform(size=(i,k))
    test_matrix_2 = np.random.uniform(size=(i,l))
    X = np.random.uniform(size=(j,k,l))
    np.einsum("ik,jkl,il->ij",*[test_matrix_1, X, test_matrix_2])   # bilinear form, output shape (i, j)
    np.einsum("ik,jkl->ijl",*[test_matrix_1, X])                    # contracting only the first input, output shape (i, j, l)

     

     

    MultiHead Attention 

    batch_size, sequence_length, hidden_size, num_head = 2, 10, 16, 8
    hidden_states = np.random.uniform(size=(batch_size, sequence_length, hidden_size))
    hidden_states.shape # (2,10,16)
    
    W_K = np.random.uniform(size=(hidden_size, hidden_size))
    W_Q = np.random.uniform(size=(hidden_size, hidden_size))
    W_V = np.random.uniform(size=(hidden_size, hidden_size))
    head_hidden_size = hidden_size // num_head
    print(head_hidden_size) ## 2
    
    
    Q = np.einsum("ijk,kl->ijl",*[hidden_states, W_Q]) # [batch_size, sequence_length, hidden_size]
    K = np.einsum("ijk,kl->ijl",*[hidden_states, W_K])
    V = np.einsum("ijk,kl->ijl",*[hidden_states, W_V])
    print(Q.shape) # (2, 10, 16)
    
    print(np.reshape(Q,[batch_size,sequence_length,num_head,head_hidden_size]).shape)
    Q = np.reshape(Q, [batch_size, sequence_length, num_head, head_hidden_size]) # [batch_size, sequence_length, num_head, head_hidden_size]
    K = np.reshape(K, [batch_size, sequence_length, num_head, head_hidden_size]) # [batch_size, sequence_length, num_head, head_hidden_size]
    V = np.reshape(V, [batch_size, sequence_length, num_head, head_hidden_size])
    # (2,10,8,2)
    
    Q = np.einsum("ijkl->ikjl",Q) # [batch_size, num_haed, sequence_length, head_hidden_size]
    K = np.einsum("ijkl->ikjl",K)
    V = np.einsum("ijkl->ikjl",V)
    # (2,8,10,2)
    
    attention_score = np.einsum("ijkl,ijml->ijkm", Q, K)/np.sqrt(hidden_size)  # [batch_size, num_head, sequence_length, sequence_length]
    attention_score.shape #(2,8,10,10)
    
    attention_result = np.einsum("ijkl,ijlm->ikjm", attention_score, V) # [batch_size, sequence_length, num_head, head_hidden_size]
    attention_result.shape # (2,10,8,2)
    attention_result = np.reshape(attention_result, [batch_size, sequence_length, hidden_size])
    attention_result.shape # (2,10,16)
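
    Note that this numpy walkthrough keeps the raw attention scores; the einops versions below apply softmax. Continuing from attention_score above, a softmax step would look roughly like this (my own addition, using the usual max-subtraction for numerical stability):

    # softmax over the last axis (the key positions); the numpy example above skips this step
    max_score = np.max(attention_score, axis=-1, keepdims=True)
    exp_score = np.exp(attention_score - max_score)               # subtract the max for numerical stability
    attention_prob = exp_score / np.sum(exp_score, axis=-1, keepdims=True)
    print(np.allclose(attention_prob.sum(axis=-1), 1.0))          # each row of weights sums to 1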

     

    MultiHead Attention (+einops) (single-head case)

    !pip install einops
    import numpy as np
    import torch
    from einops import rearrange
    from torch import nn 
    # b = 2 , t(token)= 128, dim=512 , 3 = (q,v,k)
    dim=512
    x = torch.randn(2,128,512)
    to_qvk = nn.Linear(dim, dim * 3, bias=False) # init only
    qkv = to_qvk(x)  # [batch, tokens, dim*3 ]
    # decomposition to q,v,k
    q, k, v = tuple(rearrange(qkv, 'b t (d k) -> k b t d ', k=3))
    scale_factor = 1 / np.sqrt(dim)  # scaled dot-product attention divides the logits by sqrt(dim)
    scaled_dot_prod = torch.einsum('b i d , b j d -> b i j', q, k) * scale_factor
    attention = torch.softmax(scaled_dot_prod, dim=-1)
    attention_result = torch.einsum('b i j , b j d -> b i d', attention, v)
    attention_result.shape # (2,128,512)

     

    MultiHead Attention (+einops) (multi-head case)

    The advantage is that it can be implemented more concisely.

    # b = 2 , t(token)= 128, dim=512 , 3 = (q,v,k)
    dim=512
    heads=8
    x = torch.randn(2,128,512)
    _dim = heads * dim 
    to_qvk = nn.Linear(dim, dim * heads * 3, bias=False) # init only
    qkv = to_qvk(x)
    q, k, v = tuple(rearrange(qkv, 'b t (d k h) -> k b h t d ', k=3, h=heads))
    print(q.shape) # [2, 8, 128, 512] [b,heads,token,dim]
    scale_factor = 1 / np.sqrt(dim)  # divide the logits by sqrt(dim)
    scaled_dot_prod = torch.einsum('b h i d , b h j d -> b h i j', q, k) * scale_factor  # [b, heads, token, token]
    # if mask is not None:
    #     assert mask.shape == scaled_dot_prod.shape[2:]
    #     scaled_dot_prod = scaled_dot_prod.masked_fill(mask, -np.inf)
    attention = torch.softmax(scaled_dot_prod, dim=-1)
    out = torch.einsum('b h i j , b h j d -> b h i d', attention, v)  # [b, heads, token, dim]
    out = rearrange(out, "b h t d -> b t (h d)")  # [b, token, heads*dim]
    W_0 = nn.Linear( _dim, dim, bias=False) # init only
    # Step 6. Apply final linear transformation layer 
    out = W_0(out)  # [b, token, dim]
    out.shape

     

     

    Reference

    https://newbedev.com/understanding-numpy-s-einsum
