|
- import math
- import numpy as np
- import tensorflow as tf
-
def warmup_cosine(x, warmup=0.002):
    """Learning-rate multiplier: linear warmup followed by cosine decay.

    For ``x <= warmup`` the multiplier ramps linearly as ``x / warmup``.
    Afterwards it follows ``0.5 * (1 + cos(pi * x))``, which reaches 0 at
    ``x == 1``. Progress beyond 1 is clamped, so the multiplier stays at 0
    instead of rising again.

    Args:
        x: training progress, typically ``t / t_total`` as passed by
            ``adam``. Expected to be float32 or a Python number, because the
            warmup mask is cast to float32.
        warmup: fraction of progress spent warming up. Must be > 0.
            ``x / warmup`` is evaluated for every ``x``, so with
            ``warmup == 0`` the masked-out term becomes ``0 * inf = NaN``.

    Returns:
        A float32 tensor holding the learning-rate multiplier.
    """
    s = tf.cast(x <= warmup, tf.float32)
    # Clamp progress to 1: cos is periodic, so without this the multiplier
    # would climb back towards 1 if training runs past t_total.
    progress = tf.minimum(x, 1.0)
    return s*(x/warmup) + (1-s)*(0.5 * (1 + tf.cos(math.pi * progress)))
-
def warmup_constant(x, warmup=0.002):
    """Learning-rate multiplier: linear warmup, then a constant 1.

    Args:
        x: training progress, typically ``t / t_total`` as passed by
            ``adam``. Expected to be float32 or a Python number, because the
            warmup mask is cast to float32.
        warmup: fraction of progress spent ramping up. Must be > 0.
            ``x / warmup`` is evaluated for every ``x``, so with
            ``warmup == 0`` the masked-out term becomes ``0 * inf = NaN``.

    Returns:
        A float32 tensor: ``x / warmup`` while ``x <= warmup``, else 1.
    """
    in_warmup = tf.cast(x <= warmup, tf.float32)
    ramp = x / warmup
    return in_warmup * ramp + (1 - in_warmup) * 1
-
def warmup_linear(x, warmup=0.002):
    """Learning-rate multiplier: linear warmup with linear decay to zero.

    The multiplier is ``(x / warmup) * (1 - x)`` while ``x <= warmup`` and
    ``1 - x`` afterwards. The decay factor is clamped at 0, so the
    multiplier never goes negative once ``x`` exceeds 1.

    Args:
        x: training progress, typically ``t / t_total`` as passed by
            ``adam``. Expected to be float32 or a Python number, because the
            warmup mask is cast to float32.
        warmup: fraction of progress spent warming up. Must be > 0.
            ``x / warmup`` is evaluated for every ``x``, so with
            ``warmup == 0`` the masked-out term becomes ``0 * inf = NaN``.

    Returns:
        A non-negative float32 tensor holding the learning-rate multiplier.
    """
    s = tf.cast(x <= warmup, tf.float32)
    # Clamp the decay factor at 0 so the learning rate never turns negative
    # (which would mean gradient ascent) if training runs past t_total.
    return (s*(x/warmup) + (1-s)) * tf.maximum(1 - x, 0.0)
-
# Registry of the learning-rate multiplier functions above, so a schedule
# can be looked up by its string name.
schedules = dict(
    warmup_cosine=warmup_cosine,
    warmup_constant=warmup_constant,
    warmup_linear=warmup_linear,
)
-
def adam(params, grads, lr, schedule, t_total, b1=0.9, b2=0.999, e=1e-8, l2=0, vector_l2=False, max_grad_norm=-1, **kwargs):
    """Build one Adam update step with decoupled ("fixed") weight decay.

    Creates a float32 step counter plus first/second-moment slot variables
    for every trainable parameter. Returns a single op that advances all of
    them.

    Args:
        params: variables to update.
        grads: gradients aligned with ``params``. Pairs where either the
            parameter or the gradient is ``None`` are reported and skipped.
            ``tf.IndexedSlices`` gradients are converted to dense tensors.
        lr: base learning rate.
        schedule: callable mapping progress ``t / t_total`` to a
            learning-rate multiplier (e.g. one of the ``warmup_*``
            functions).
        t_total: total number of update steps, used to normalise progress.
        b1: decay rate for the first-moment estimate.
        b2: decay rate for the second-moment estimate.
        e: epsilon added to the denominator for numerical stability.
        l2: decoupled weight-decay coefficient. Applied only when
            ``l2 > 0`` and the parameter has rank > 1 (or ``vector_l2``).
        vector_l2: also apply weight decay to rank <= 1 parameters.
        max_grad_norm: if > 0, clip all gradients to this global norm first.
        **kwargs: unused.

    Returns:
        A ``tf.group`` op that increments the step counter and applies the
        moment and parameter updates.
    """
    t = tf.Variable(0, dtype=tf.float32, trainable=False)
    tt = t + 1
    updates = [t.assign(tt)]
    if max_grad_norm > 0:
        grads, _ = tf.clip_by_global_norm(grads, max_grad_norm)

    # Bias-corrected, scheduled step size. It is identical for every
    # parameter, so build it once instead of once per parameter.
    # NOTE(review): in graph mode, ``t`` is read here with no control
    # dependency on ``t.assign(tt)``; confirm the intended read ordering.
    lrt = lr * tf.sqrt(1 - b2**tt) / (1 - b1**tt)
    lrt *= schedule(t / t_total)

    for p, g in zip(params, grads):
        if p is None or g is None:
            # ``p`` itself may be None, so don't assume it has ``.name``.
            print("can't train", getattr(p, 'name', p), g)
            continue
        if isinstance(g, tf.IndexedSlices):
            g = tf.convert_to_tensor(g)
        # First/second moment slots, zero-initialised in float32.
        m = tf.Variable(tf.zeros_like(p, dtype=tf.float32), trainable=False)
        v = tf.Variable(tf.zeros_like(p, dtype=tf.float32), trainable=False)
        mt = b1 * m + (1 - b1) * g
        vt = b2 * v + (1 - b2) * g * g
        direction = mt / (tf.sqrt(vt) + e)
        # Decoupled weight decay: l2 * p is added to the update direction
        # rather than to g, so it never enters the moment estimates.
        if (len(p.get_shape()) > 1 or vector_l2) and l2 > 0:
            direction = direction + l2 * p
        pt = p - lrt * direction
        updates.extend([m.assign(mt), v.assign(vt), p.assign(pt)])
    return tf.group(*updates)
|