def forward(self, x):
    """Route input capsules to output capsules with dynamic routing.

    Args:
        x: tensor of shape [B, L, H] (e.g. the output of a CNN / RNN).
            L is treated as the number of input capsules
            (input_num_capsules) and H as the dimension of each input
            capsule (input_dim_capsule).

    Returns:
        Tensor of shape [B, num_capsule, dim_capsule].

    Raises:
        ValueError: if ``self.num_iterations`` is less than 1. Without at
            least one routing pass no output capsules are ever computed.
    """
    if self.num_iterations < 1:
        raise ValueError(
            f"num_iterations must be >= 1, got {self.num_iterations}")

    B, I, _ = x.size()  # I = number of input capsules
    O, F = self.num_capsule, self.dim_capsule

    # Prediction vectors: project every input capsule into O output
    # capsules of size F. The view below requires x @ W to have a last
    # dimension of O * F (presumably W is (H, O * F); confirm at __init__).
    u = torch.matmul(x, self.W)
    u = u.view(B, I, O, F).transpose(1, 2)  # [B, O, I, F]

    # Routing logits start at zero, giving uniform coupling on the first
    # pass. zeros_like already matches u's device and dtype.
    b = torch.zeros_like(u[:, :, :, 0])  # [B, O, I]

    for i in range(self.num_iterations):
        # Coupling coefficients are normalised over the output capsules
        # (dim=1), so each input capsule distributes a total weight of 1.
        c = torch.softmax(b, dim=1)  # [B, O, I]
        v = torch.einsum('boi,boif->bof', [c, u])  # [B, O, F]
        v = self.squash(v)
        if i < self.num_iterations - 1:
            # The logits are *replaced* by the agreement v . u rather than
            # accumulated. That is how this implementation always worked;
            # it differs from the additive update in Sabour et al. 2017.
            # After the last pass the new logits would be unused, so the
            # update is skipped there.
            b = torch.einsum('bof,boif->boi', [v, u])  # [B, O, I]
    return v  # [B, num_capsule, dim_capsule]