Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
+# Dump directory for prototyping and testing purposes
+dump/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
2 changes: 1 addition & 1 deletion src/model/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def generate(self, src_enc, src_len, max_len=200, sample_temperature=None):

# add <EOS> to unfinished sentences
if cur_len == max_len:
-generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index)
+generated[-1].masked_fill_(unfinished_sents.bool(), self.eos_index)

# sanity check
assert (generated == self.eos_index).sum() == 2 * bs
Expand Down
2 changes: 1 addition & 1 deletion src/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ def generate(self, src_enc, src_len, max_len=200, sample_temperature=None):

# add <EOS> to unfinished sentences
if cur_len == max_len:
-generated[-1].masked_fill_(unfinished_sents.byte(), self.eos_index)
+generated[-1].masked_fill_(unfinished_sents.bool(), self.eos_index)

# sanity check
assert (generated == self.eos_index).sum() == 2 * bs
Expand Down
8 changes: 4 additions & 4 deletions src/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def step(self, closure=None):
# grad.add_(group['weight_decay'], p.data)

# Decay the first and second moment running average coefficient
-exp_avg.mul_(beta1).add_(1 - beta1, grad)
-exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
denom = exp_avg_sq.sqrt().add_(group['eps'])
# denom = exp_avg_sq.sqrt().clamp_(min=group['eps'])

Expand All @@ -84,9 +84,9 @@ def step(self, closure=None):
step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

if group['weight_decay'] != 0:
-p.data.add_(-group['weight_decay'] * group['lr'], p.data)
+p.data.add_(p.data, alpha=-group['weight_decay'] * group['lr'])

-p.data.addcdiv_(-step_size, exp_avg, denom)
+p.data.addcdiv_(exp_avg, denom, value=-step_size)

return loss

Expand Down