import mxnet as mx
from mxnet import autograd, gluon, init, nd
from mxnet.gluon import loss as gloss, nn, rnn
from MultHeadAttention import Context_Attention
def try_gpu():
    """Return mx.gpu() if a GPU is available, otherwise mx.cpu()."""
    # This function is also saved in the d2lzh package for later use.
    try:
        ctx = mx.gpu()
        _ = nd.zeros((1,), ctx=ctx)
    except mx.base.MXNetError:
        ctx = mx.cpu()
    return ctx
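
# Usage sketch for try_gpu (illustrative only): pick the context once and
# allocate every array on it.
# ctx = try_gpu()
# x = nd.zeros((2, 3), ctx=ctx)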
class RnnModel(nn.Block):
    """Attention-based LSTM decoder."""

    def __init__(self, EMBEDDING_DIM, INPUT_DIM, LATENT_DIM, **kwargs):
        super(RnnModel, self).__init__(**kwargs)
        self.Embedding_dim = EMBEDDING_DIM
        self.Input_DIM = INPUT_DIM
        self.Latent_dim = LATENT_DIM
        self.decoder_lstm = rnn.LSTM(LATENT_DIM)
        self.context_attention = Context_Attention()
        # nd.Embedding is applied as a function in forward(); the embedding
        # table is passed in explicitly rather than learned by this Block.
        self.embedding = nd.Embedding

    def begin_state(self, *args, **kwargs):
        return self.decoder_lstm.begin_state(*args, **kwargs)
    def forward(self, maxlen_output, maxlen_input, decoder_inputs, weight, encoder_output, s, c):
        # nd.Embedding's positional signature is (data, weight, input_dim, output_dim),
        # so here `weight` carries the token indices and `decoder_inputs` the
        # embedding table.
        self.decoder_input = self.embedding(weight, decoder_inputs, self.Input_DIM, self.Embedding_dim)
        outputs = []
        for i in range(maxlen_output):
            # Attention context over the encoder output, conditioned on the
            # current hidden state s.
            context = self.context_attention(maxlen_input, encoder_output, s)
            # Take the i-th decoder input step, repeating the last available
            # step once i runs past the input length; slicing keeps the time axis.
            t = min(i, self.decoder_input.shape[0] - 1)
            x_t = self.decoder_input[t:t + 1]
            decoder_lstm_input = nd.concat(context, x_t, dim=0)
            output, state = self.decoder_lstm(decoder_lstm_input, [s, c])
            s, c = state
            # Flatten (T, N, C) to (T, N*C), normalize each step, then keep the
            # per-feature maximum over time and its overall maximum.
            output = output.reshape((output.shape[0], output.shape[1] * output.shape[2]))
            decoder_outputs = nd.softmax(output, axis=1)
            decoder_outputs = nd.max(decoder_outputs, axis=0)
            if i == 0:
                outputs = nd.max(decoder_outputs)
            else:
                outputs = nd.concat(outputs, nd.max(decoder_outputs), dim=0)
        return outputs
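
# Context_Attention comes from MultHeadAttention.py, which is not shown here.
# A minimal stand-in with the interface forward() relies on might look like
# the sketch below (hypothetical; the real block presumably computes attention
# weights over encoder_output). It must return a (1, batch, features) context
# so that nd.concat(context, x_t, dim=0) lines up with x_t on the batch and
# feature axes.
#
# class Context_Attention(nn.Block):
#     def forward(self, maxlen_input, encoder_output, s):
#         # ... compute attention scores against s and pool encoder_output ...
#         return context  # shape (1, batch, features)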
# if __name__ == '__main__':
#     ctx = try_gpu()
#     array = nd.random.normal(shape=(19, 19))
#     decoder_inputs = nd.array([[int(i) for i in range(j * 19, j * 19 + 19)]
#                                for j in range(19 * 19)], dtype="float32")
#     weight = nd.array([[0, 3, 4, 3, 3, 2, 3, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13]])
#
#     r = RnnModel(EMBEDDING_DIM=19, INPUT_DIM=19 * 19, LATENT_DIM=19)
#     s, c = r.begin_state(batch_size=19, ctx=ctx)
#     r.initialize()
#     output = r(19 * 19, 200, decoder_inputs, weight, array, s, c)
#     print(output.shape)
#     poss = nd.zeros(shape=(361,), dtype="float32")
#     poss[0] = 0.0065
#     poss[1] = 0.0025
#     poss[25] = 0.00013
#     poss[34] = 0.0012
#     poss[44] = 0.0011
#     poss[90] = 0.001023
#     poss[100] = 0.00023
#
#     s.detach()
#     c.detach()
#     loss = gloss.L2Loss()
#     trainer = gluon.Trainer(r.collect_params(), 'sgd',
#                             {'learning_rate': 1e2, 'momentum': 0, 'wd': 0})
#     print(r.collect_params())
#     with autograd.record():
#         l = loss(output, poss).sum()
#     l.backward()
#     trainer.step(1)