# -*- coding:utf-8 -*-
import argparse
import os

import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L
import pickle

n_units = 512  # number of units in the hidden layers

# LSTM network definition
class LSTM(chainer.Chain):
    state = {}

    def __init__(self, n_vocab, n_units):
        super(LSTM, self).__init__(
            l1_embed=L.EmbedID(n_vocab, n_units),
            l1_x=L.Linear(n_units, 4 * n_units),
            l1_h=L.Linear(n_units, 4 * n_units),
            l2_embed=L.EmbedID(n_vocab, n_units),
            l2_x=L.Linear(n_units, 4 * n_units),
            l2_h=L.Linear(n_units, 4 * n_units),
            l3_embed=L.EmbedID(n_vocab, n_units),
            l3_x=L.Linear(n_units, 4 * n_units),
            l3_h=L.Linear(n_units, 4 * n_units),
            l4_embed=L.EmbedID(n_vocab, n_units),
            l4_x=L.Linear(n_units, 4 * n_units),
            l4_h=L.Linear(n_units, 4 * n_units),
            l5_embed=L.EmbedID(n_vocab, n_units),
            l5_x=L.Linear(n_units, 4 * n_units),
            l5_h=L.Linear(n_units, 4 * n_units),
            l_umembed=L.Linear(n_units, n_vocab)
        )
        self.n_vocab = n_vocab
        self.n_units = n_units

    def __str__(self):
        return "{} {}".format(self.n_vocab, self.n_units)

    def forward(self, x1, x2, x3, x4, x5, t, dropout_ratio=0.2):
        # Each of the five context words gets its own embedding and LSTM
        # step; the cell state of step k feeds step k+1.
        h1 = self.l1_embed(chainer.Variable(np.asarray([x1])))
        c1, y1 = F.lstm(chainer.Variable(np.zeros((1, n_units), dtype=np.float32)),
                        F.dropout(self.l1_x(h1), ratio=dropout_ratio, train=False) + self.l1_h(self.state['y1']))
        h2 = self.l2_embed(chainer.Variable(np.asarray([x2])))
        c2, y2 = F.lstm(self.state['c1'],
                        F.dropout(self.l2_x(h2), ratio=dropout_ratio, train=False) + self.l2_h(self.state['y2']))
        h3 = self.l3_embed(chainer.Variable(np.asarray([x3])))
        c3, y3 = F.lstm(self.state['c2'],
                        F.dropout(self.l3_x(h3), ratio=dropout_ratio, train=False) + self.l3_h(self.state['y3']))
        h4 = self.l4_embed(chainer.Variable(np.asarray([x4])))
        c4, y4 = F.lstm(self.state['c3'],
                        F.dropout(self.l4_x(h4), ratio=dropout_ratio, train=False) + self.l4_h(self.state['y4']))
        h5 = self.l5_embed(chainer.Variable(np.asarray([x5])))
        c5, y5 = F.lstm(self.state['c4'],
                        F.dropout(self.l5_x(h5), ratio=dropout_ratio, train=False) + self.l5_h(self.state['y5']))
        self.state = {'c1': c1, 'y1': y1, 'h1': h1,
                      'c2': c2, 'y2': y2, 'h2': h2,
                      'c3': c3, 'y3': y3, 'h3': h3,
                      'c4': c4, 'y4': y4, 'h4': h4,
                      'c5': c5, 'y5': y5, 'h5': h5}
        # Un-embed back into vocabulary space and return the distribution
        y = self.l_umembed(y5)
        print('y:', vars(y))
        print('t:', np.asarray([t]))

        return F.softmax(y), y.data

    def initialize_state(self, n_units, batchsize=1, train=True):
        for name in ('c1', 'y1', 'h1', 'c2', 'y2', 'h2', 'c3', 'y3', 'h3',
                     'c4', 'y4', 'h4', 'c5', 'y5', 'h5'):
            self.state[name] = chainer.Variable(
                np.zeros((batchsize, n_units), dtype=np.float32),
                volatile=not train)
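
# A minimal sketch (not part of the original script) of how F.lstm consumes
# the 4 * n_units projections defined above; the names c and x are
# illustrative only:
#
#     c = chainer.Variable(np.zeros((1, n_units), dtype=np.float32))
#     x = chainer.Variable(np.zeros((1, 4 * n_units), dtype=np.float32))
#     c_new, h_new = F.lstm(c, x)  # both results have shape (1, n_units)
#
# F.lstm splits its second argument into four n_units-wide blocks (cell input
# plus the input/forget/output gates), which is why every l*_x and l*_h layer
# projects to 4 * n_units.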


def load_data(filename, vocab, inv_vocab):
    # Replace every newline in the text with <eos>, then split into words.
    # For Japanese text, spaces must surround <eos>: replace('\n', ' <eos> ')
    words = open(filename, encoding='utf-8').read().replace('\n', ' <eos> ').strip().split()
    dataset = np.ndarray((len(words),), dtype=np.int32)
    for i, word in enumerate(words):
        if word not in vocab:
            vocab[word] = len(vocab)
            inv_vocab[len(vocab) - 1] = word
        dataset[i] = vocab[word]
    return dataset
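
# Example of load_data's behavior (a sketch, not from the original script;
# 'tiny.txt' is a hypothetical file containing the two lines "a b" and "b c"):
#
#     vocab, inv_vocab = {}, {}
#     data = load_data('tiny.txt', vocab, inv_vocab)
#     # words -> ['a', 'b', '<eos>', 'b', 'c', '<eos>']
#     # vocab -> {'a': 0, 'b': 1, '<eos>': 2, 'c': 3}
#     # data  -> array([0, 1, 2, 1, 3, 2], dtype=int32)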


def main():
    ''' main function '''
    p = 5           # sequence length
    w = 2           # number of words on each side of the target
    total_loss = 0  # accumulator for the loss value
    vocab = {}
    n_vocab = len(vocab)
    inv_vocab = {}  # reverse dictionary (id -> word)

    # Argument handling
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', '-g', default=-1, type=int,
                        help='GPU ID (negative value indicates CPU)')
    args = parser.parse_args()
    # In a CUDA environment, do the following:
    # if args.gpu >= 0:
    #     xp = cuda.cupy
    # else:
    #     xp = np
    # if args.gpu >= 0:
    #     cuda.get_device(args.gpu).use()
    #     model.to_gpu()

    with open('vocab.bin', 'rb') as fv:
        vocab = pickle.load(fv)
    print("vocab")
    print(len(vocab))

    with open('inv_vocab.bin', 'rb') as fi:
        inv_vocab = pickle.load(fi)
    print("inv_vocab")
    print(len(inv_vocab))
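    # vocab.bin / inv_vocab.bin are assumed to have been written by the
    # training script, e.g. (a sketch, not shown in this file):
    #     with open('vocab.bin', 'wb') as f:
    #         pickle.dump(vocab, f)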
    # Load the test data
    os.chdir('/Users/suguruoki/practice/chainer/examples/ptb/')
    test_data = load_data('117-test-data.dat', vocab, inv_vocab)
    print("test data loaded")
    n_vocab = len(vocab)
    print("n_vocab")
    print(n_vocab)

    # Prepare the model: the input size is the vocabulary size; the hidden
    # size is defined at the top of this file.
    lstm = LSTM(n_vocab, n_units)
    lstm.initialize_state(n_units)  # forward() reads self.state, so it must be initialized

    # model.compute_accuracy = False
    model = L.Classifier(lstm)  # NOTE: this instance differs from the pickled one loaded below
    with open('LSTMmodel.pkl', 'rb') as fm:
        model = pickle.load(fm)
    # with open('LSTMlstm.pkl', 'rb') as fl:
    #     lstm = pickle.load(fl)
    t = 2
    seq = []  # seq: list of surrounding context words
    for k in range(t - w, t + w + 1):
        if k >= 0:
            if k == t:
                seq.append(vocab['<$>'])   # placeholder at the target position
            elif k > len(test_data) - 1:
                seq.append(vocab['<s>'])   # past the end of the data
            else:
                seq.append(test_data[k])
        else:
            seq.append(vocab['<s>'])       # before the start of the data
    seq.append(test_data[t])               # the target word itself
    print('t =', t, ', seq :', seq)
    print(seq[0])
    # Clamp out-of-range ids to 1
    seq = [s if s < len(test_data) else 1 for s in seq]


    # Run the forward pass: five context word ids plus the target id.
    pr, y = lstm.forward(seq[0], seq[1], seq[2], seq[3], seq[4], seq[5])
    # Pair each output score with its word and sort by descending score.
    prediction = list(zip(y[0].tolist(), inv_vocab.values()))
    prediction.sort()
    prediction.reverse()
    print(prediction)

if __name__ == '__main__':
    main()


