def batch_data(x, batch_size=128, window=2):
    """Build skip-gram (center, context) word-id pairs from tokenized texts.

    The original body appended the center word ``text[i]`` four times
    (2 * window for window == 2) but never filled ``out_w`` or ``target``
    and returned nothing, so it had no observable effect.  This version
    pairs each center word with every context word inside the window and
    returns the two parallel lists.

    Args:
        x: iterable of token-id sequences (one per line/sentence).
        batch_size: kept for interface compatibility; batching is not
            performed here (the original never used it either).
        window: context radius on each side of the center word.
            NOTE(review): the original read a module-level ``window``;
            the default 2 matches the original's four center appends —
            confirm the global's value before relying on it.

    Returns:
        (in_w, out_w): parallel lists where ``in_w[k]`` is a center word
        id and ``out_w[k]`` is one of its context word ids.
    """
    in_w = []
    out_w = []
    for text in x:
        # Only positions with a full window on both sides are used,
        # mirroring the original range(window, len(text) - window).
        for i in range(window, len(text) - window):
            for j in range(i - window, i + window + 1):
                if j == i:
                    continue  # never pair the center word with itself
                in_w.append(text[i])
                out_w.append(text[j])
    return in_w, out_w
# --- Build the vocabulary and convert the corpus into token-id lines. ---
# NOTE(review): assumes `re`, `vocab_dict` (dict), `text_id` (list),
# `SkipGram`, `embd_size`, and `device` are defined earlier in the file.
with open('corpus.txt', encoding='utf-8') as fp:
    for line in fp:
        # Keep only alphanumerics and apostrophes, lowercase, tokenize.
        tokens = re.sub("[^A-Za-z0-9']+", ' ', line).lower().split()
        line_id = []
        for s in tokens:
            if not s:  # defensive only — str.split() never yields empty strings
                continue
            # First sighting of a word assigns it the next free id.
            if s not in vocab_dict:
                vocab_dict[s] = len(vocab_dict)
            token_id = vocab_dict[s]  # renamed from `id`: don't shadow the builtin
            line_id.append(token_id)
            if token_id == 11500:  # leftover debug probe — consider removing
                print(token_id, s)
        text_id.append(line_id)

vocab_size = len(vocab_dict)
print('vocab_size', vocab_size)
model = SkipGram(vocab_size, embd_size).to(device)
# --- Training loop. ---
# Create the optimizer ONCE, before the epoch loop: the original rebuilt
# torch.optim.Adam every epoch, which silently discarded Adam's running
# first/second-moment estimates and effectively restarted optimization
# each epoch.
opt = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr=0.001,
    weight_decay=0,
)
for epoch in range(epochs):
    print('epoch', epoch)
    # NOTE(review): `train` is defined elsewhere; presumably it runs one
    # full pass over `text_id` and steps `opt` — confirm its contract.
    train(text_id, model, opt)