|
- """
- Multi-gate Mixture-of-Experts model implementation.
-
- Copyright (c) 2018 Drawbridge, Inc
- Licensed under the MIT License (see LICENSE for details)
- Written by Alvin Deng
- """
- from keras import backend as K
- from keras import activations, initializers, regularizers, constraints
- from keras.engine.topology import Layer, InputSpec # InputSpec指定图层的每个输入的ndim,dtype和形状.
- import tensorflow as tf
- from tensorflow_core.python.ops import init_ops
-
-
class MMoE(Layer):
    """
    Multi-gate Mixture-of-Experts (MMoE) layer.

    The layer owns one bank of ``num_experts`` expert transforms that is shared
    by all tasks, plus one gate per task. Each gate produces one weight per
    expert, and the output for a task is the gate-weighted sum of the expert
    outputs. Calling the layer returns a list with ``num_tasks`` tensors.
    """

    def __init__(self,
                 units,
                 num_experts,
                 num_tasks,
                 use_expert_bias=True,
                 use_gate_bias=True,
                 expert_activation='relu',
                 gate_activation='softmax',
                 expert_bias_initializer='zeros',
                 gate_bias_initializer='zeros',
                 expert_bias_regularizer=None,
                 gate_bias_regularizer=None,
                 expert_bias_constraint=None,
                 gate_bias_constraint=None,
                 expert_kernel_initializer=init_ops.he_normal(seed=1),  # 'VarianceScaling'
                 gate_kernel_initializer=init_ops.he_normal(seed=1),  # 'VarianceScaling'
                 expert_kernel_regularizer=None,
                 gate_kernel_regularizer=None,
                 expert_kernel_constraint=None,
                 gate_kernel_constraint=None,
                 activity_regularizer=None,
                 **kwargs):
        """
        Method for instantiating MMoE layer.

        :param units: Number of hidden units
        :param num_experts: Number of experts
        :param num_tasks: Number of tasks
        :param use_expert_bias: Boolean to indicate the usage of bias in the expert weights
        :param use_gate_bias: Boolean to indicate the usage of bias in the gate weights
        :param expert_activation: Activation function of the expert weights
        :param gate_activation: Activation function of the gate weights
        :param expert_bias_initializer: Initializer for the expert bias
        :param gate_bias_initializer: Initializer for the gate bias
        :param expert_bias_regularizer: Regularizer for the expert bias
        :param gate_bias_regularizer: Regularizer for the gate bias
        :param expert_bias_constraint: Constraint for the expert bias
        :param gate_bias_constraint: Constraint for the gate bias
        :param expert_kernel_initializer: Initializer for the expert weights
        :param gate_kernel_initializer: Initializer for the gate weights
        :param expert_kernel_regularizer: Regularizer for the expert weights
        :param gate_kernel_regularizer: Regularizer for the gate weights
        :param expert_kernel_constraint: Constraint for the expert weights
        :param gate_kernel_constraint: Constraint for the gate weights
        :param activity_regularizer: Regularizer for the activity
        :param kwargs: Additional keyword arguments for the Layer class
        """
        # Hidden nodes parameter
        self.units = units
        self.num_experts = num_experts
        self.num_tasks = num_tasks

        # Weight parameter (the weights themselves are created in build())
        self.expert_kernels = None
        self.gate_kernels = None
        self.expert_kernel_initializer = initializers.get(expert_kernel_initializer)
        self.gate_kernel_initializer = initializers.get(gate_kernel_initializer)
        self.expert_kernel_regularizer = regularizers.get(expert_kernel_regularizer)
        self.gate_kernel_regularizer = regularizers.get(gate_kernel_regularizer)
        self.expert_kernel_constraint = constraints.get(expert_kernel_constraint)
        self.gate_kernel_constraint = constraints.get(gate_kernel_constraint)

        # Activation parameter
        self.expert_activation = activations.get(expert_activation)
        self.gate_activation = activations.get(gate_activation)

        # Bias parameter (the biases themselves are created in build())
        self.expert_bias = None
        self.gate_bias = None
        self.use_expert_bias = use_expert_bias
        self.use_gate_bias = use_gate_bias
        self.expert_bias_initializer = initializers.get(expert_bias_initializer)
        self.gate_bias_initializer = initializers.get(gate_bias_initializer)
        self.expert_bias_regularizer = regularizers.get(expert_bias_regularizer)
        self.gate_bias_regularizer = regularizers.get(gate_bias_regularizer)
        self.expert_bias_constraint = constraints.get(expert_bias_constraint)
        self.gate_bias_constraint = constraints.get(gate_bias_constraint)

        # Activity parameter
        self.activity_regularizer = regularizers.get(activity_regularizer)

        # Keras parameter
        self.input_spec = InputSpec(min_ndim=2)
        self.supports_masking = True

        # Invoke the parent Layer initializer with the remaining keyword arguments.
        # NOTE(review): the base initializer runs *after* input_spec and
        # supports_masking are assigned above; in standalone Keras 2.x it
        # presumably assigns both attributes itself, which would override
        # them - confirm before relying on masking support.
        super(MMoE, self).__init__(**kwargs)

    def build(self, input_shape):
        """
        Method for creating the layer weights.

        :param input_shape: Shape of the input (at least two entries); its last
                            entry is used as the input feature dimension
        """
        assert input_shape is not None and len(input_shape) >= 2, \
            'MMoE expects an input shape with at least 2 dimensions'

        input_dimension = input_shape[-1]

        # Initialize expert weights (number of input features * number of units per expert * number of experts)
        self.expert_kernels = self.add_weight(
            name='expert_kernel',
            shape=(input_dimension, self.units, self.num_experts),
            initializer=self.expert_kernel_initializer,
            regularizer=self.expert_kernel_regularizer,
            constraint=self.expert_kernel_constraint,
        )

        # Initialize expert bias (number of units per expert * number of experts)
        if self.use_expert_bias:
            self.expert_bias = self.add_weight(
                name='expert_bias',
                shape=(self.units, self.num_experts),
                initializer=self.expert_bias_initializer,
                regularizer=self.expert_bias_regularizer,
                constraint=self.expert_bias_constraint,
            )

        # Initialize gate weights: one (number of input features * number of experts) kernel per task
        self.gate_kernels = [
            self.add_weight(
                name='gate_kernel_task_{}'.format(i),
                shape=(input_dimension, self.num_experts),
                initializer=self.gate_kernel_initializer,
                regularizer=self.gate_kernel_regularizer,
                constraint=self.gate_kernel_constraint,
            )
            for i in range(self.num_tasks)
        ]

        # Initialize gate bias: one (number of experts,) vector per task
        if self.use_gate_bias:
            self.gate_bias = [
                self.add_weight(
                    name='gate_bias_task_{}'.format(i),
                    shape=(self.num_experts,),
                    initializer=self.gate_bias_initializer,
                    regularizer=self.gate_bias_regularizer,
                    constraint=self.gate_bias_constraint,
                )
                for i in range(self.num_tasks)
            ]

        # Pin the feature dimension now that the weights depend on it
        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_dimension})

        super(MMoE, self).build(input_shape)

    def call(self, inputs, **kwargs):
        """
        Method for the forward function of the layer.

        :param inputs: Input tensor (assumed to be 2D, i.e. (batch, features) - TODO confirm)
        :param kwargs: Additional keyword arguments for the base method
        :return: List of tensors, one per task
        """
        # f_{i}(x) = activation(W_{i} * x + b), where activation is ReLU according to the paper.
        # The module-level `tf` import is used instead of `K.tf`, which is not
        # exposed by every Keras backend module.
        expert_outputs = tf.tensordot(a=inputs, b=self.expert_kernels, axes=1)
        # Add the bias term to the expert weights if necessary
        if self.use_expert_bias:
            expert_outputs = K.bias_add(x=expert_outputs, bias=self.expert_bias)
        expert_outputs = self.expert_activation(expert_outputs)

        final_outputs = []
        for index, gate_kernel in enumerate(self.gate_kernels):
            # g^{k}(x) = activation(W_{gk} * x + b), where activation is softmax according to the paper
            gate_output = K.dot(x=inputs, y=gate_kernel)
            # Add the bias term to the gate weights if necessary
            if self.use_gate_bias:
                gate_output = K.bias_add(x=gate_output, bias=self.gate_bias[index])
            gate_output = self.gate_activation(gate_output)

            # f^{k}(x) = sum_{i=1}^{n}(g^{k}(x)_{i} * f_{i}(x))
            # The gate gets a singleton axis at position 1 and is broadcast
            # across the units axis, which yields the same product as
            # explicitly repeating it `units` times.
            weighted_expert_output = expert_outputs * K.expand_dims(gate_output, axis=1)
            final_outputs.append(K.sum(weighted_expert_output, axis=2))

        return final_outputs

    def compute_output_shape(self, input_shape):
        """
        Method for computing the output shape of the MMoE layer.

        :param input_shape: Shape tuple (tuple of integers)
        :return: List of output shape tuples (the input shape with its last
                 entry replaced by ``units``); the list has one entry per task
        """
        assert input_shape is not None and len(input_shape) >= 2, \
            'MMoE expects an input shape with at least 2 dimensions'

        output_shape = list(input_shape)
        output_shape[-1] = self.units
        output_shape = tuple(output_shape)

        return [output_shape for _ in range(self.num_tasks)]

    def get_config(self):
        """
        Method for returning the configuration of the MMoE layer.

        :return: Config dictionary (base Layer config merged with the MMoE
                 constructor arguments in serialized form)
        """
        config = {
            'units': self.units,
            'num_experts': self.num_experts,
            'num_tasks': self.num_tasks,
            'use_expert_bias': self.use_expert_bias,
            'use_gate_bias': self.use_gate_bias,
            'expert_activation': activations.serialize(self.expert_activation),
            'gate_activation': activations.serialize(self.gate_activation),
            'expert_bias_initializer': initializers.serialize(self.expert_bias_initializer),
            'gate_bias_initializer': initializers.serialize(self.gate_bias_initializer),
            'expert_bias_regularizer': regularizers.serialize(self.expert_bias_regularizer),
            'gate_bias_regularizer': regularizers.serialize(self.gate_bias_regularizer),
            'expert_bias_constraint': constraints.serialize(self.expert_bias_constraint),
            'gate_bias_constraint': constraints.serialize(self.gate_bias_constraint),
            'expert_kernel_initializer': initializers.serialize(self.expert_kernel_initializer),
            'gate_kernel_initializer': initializers.serialize(self.gate_kernel_initializer),
            'expert_kernel_regularizer': regularizers.serialize(self.expert_kernel_regularizer),
            'gate_kernel_regularizer': regularizers.serialize(self.gate_kernel_regularizer),
            'expert_kernel_constraint': constraints.serialize(self.expert_kernel_constraint),
            'gate_kernel_constraint': constraints.serialize(self.gate_kernel_constraint),
            'activity_regularizer': regularizers.serialize(self.activity_regularizer),
        }
        base_config = super(MMoE, self).get_config()
        # Entries from `config` take precedence over same-named base entries
        return dict(list(base_config.items()) + list(config.items()))
|