# Copyright 2019 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from abc import ABC, abstractmethod
from logging import getLogger
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import tensorflow as tf
from bert_dp.modeling import BertConfig, BertModel
from bert_dp.optimization import AdamWeightDecayOptimizer
from bert_dp.preprocessing import InputFeatures
from overrides import overrides
from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.errors import ConfigError
from deeppavlov.core.common.registry import register
from deeppavlov.core.layers.tf_layers import bi_rnn
from deeppavlov.core.models.tf_model import LRScheduledTFModel
from deeppavlov.models.bert.bert_sequence_tagger import token_from_subtoken
log = getLogger(__name__)
[docs]class MTBertTask(ABC):
"""Abstract class for multitask BERT tasks. Objects of its subclasses are linked with BERT body when
``MultiTaskBert.build`` method is called. Training is performed with ``MultiTaskBert.train_on_batch`` method is
called. The objects of classes derived from ``MTBertTask`` don't have ``__call__`` method. Instead they have
``get_sess_run_infer_args`` and ``post_process_preds`` methods, which are called from ``call`` method of
``MultiTaskBert`` class. ``get_sess_run_infer_args`` method returns fetches and feed_dict for inference and
``post_process_preds`` method retrieves predictions from computed fetches. Classes derived from ``MTBertTask``
must ``get_sess_run_train_args`` method that returns fetches and feed_dict for training.
Args:
keep_prob: dropout keep_prob for non-BERT layers
return_probas: set this to ``True`` if you need the probabilities instead of raw answers
learning_rate: learning rate of BERT head
"""
def __init__(
self,
keep_prob: float = 1.,
return_probas: bool = None,
learning_rate: float = 1e-3,
):
self.keep_prob = keep_prob
self.return_probas = return_probas
self.init_head_learning_rate = learning_rate
self.min_body_learning_rate = None
self.head_learning_rate_multiplier = None
self.bert = None
self.optimizer_params = None
self.shared_ph = None
self.shared_feed_dict = None
self.sess = None
self.get_train_op_func = None
self.freeze_embeddings = None
self.bert_head_variable_scope = None
[docs] def build(
self,
bert_body: BertModel,
optimizer_params: Dict[str, Union[str, float]],
shared_placeholders: Dict[str, tf.placeholder],
sess: tf.Session,
mode: str,
get_train_op_func: Callable,
freeze_embeddings: bool,
bert_head_variable_scope: str) -> None:
"""Initiates building of the BERT head and initializes optimizer parameters, placeholders that are common for
all tasks.
Args:
bert_body: instance of ``BertModel``.
optimizer_params: a dictionary with four fields:
``'optimizer'`` (``str``) -- a name of optimizer class,
``'body_learning_rate'`` (``float``) -- initial value of BERT body learning rate,
``'min_body_learning_rate'`` (``float``) -- min BERT body learning rate for learning rate decay,
``'weight_decay_rate'`` (``float``) -- L2 weight decay for ``AdamWeightDecayOptimizer``
shared_placeholders: a dictionary with placeholders used in all tasks. The dictionary contains fields
``'input_ids'``, ``'input_masks'``, ``'learning_rate'``, ``'keep_prob'``, ``'is_train'``,
``'token_types'``.
sess: current ``tf.Session`` instance
mode: ``'train'`` or ``'inference'``
get_train_op_func: a function returning ``tf.Operation`` and with signature similar to
``LRScheduledTFModel.get_train_op`` without ``self`` argument. It is a function returning train
operation for specified loss and variable scopes.
freeze_embeddings: set ``False`` to train input embeddings.
bert_head_variable_scope: variable scope for BERT head.
"""
self.bert_head_variable_scope = bert_head_variable_scope
self.get_train_op_func = get_train_op_func
self.freeze_embeddings = freeze_embeddings
self.bert = bert_body
self.optimizer_params = optimizer_params
if mode == 'train':
self.head_learning_rate_multiplier = \
self.init_head_learning_rate / self.optimizer_params['body_learning_rate']
else:
self.head_learning_rate_multiplier = 0
mblr = self.optimizer_params.get('min_body_learning_rate')
self.min_body_learning_rate = 0. if mblr is None else mblr
self.shared_ph = shared_placeholders
self.sess = sess
self._init_graph()
if mode == 'train':
self._init_optimizer()
[docs] @abstractmethod
def _init_graph(self) -> None:
"""Build BERT head, initialize task specific placeholders, create attributes containing output probabilities
and model loss. Optimizer initialized not in this method but in ``_init_optimizer``."""
pass
[docs] def get_train_op(self, loss: tf.Tensor, body_learning_rate: Union[tf.Tensor, float], **kwargs) -> tf.Operation:
"""Return operation for the task training. Head learning rate is calculated as a product of
``body_learning_rate`` and quotient of initial head learning rate and initial body learning rate.
Args:
loss: the task loss
body_learning_rate: the learning rate for the BERT body
Returns:
train operation for the task
"""
assert "learnable_scopes" not in kwargs, "learnable scopes unsupported"
# train_op for bert variables
kwargs['learnable_scopes'] = ('bert/encoder', 'bert/embeddings')
if self.freeze_embeddings:
kwargs['learnable_scopes'] = ('bert/encoder',)
learning_rate = body_learning_rate * self.head_learning_rate_multiplier
bert_train_op = self.get_train_op_func(loss, body_learning_rate, **kwargs)
# train_op for ner head variables
kwargs['learnable_scopes'] = (self.bert_head_variable_scope,)
head_train_op = self.get_train_op_func(loss, learning_rate, **kwargs)
return tf.group(bert_train_op, head_train_op)
def _init_optimizer(self) -> None:
with tf.variable_scope(self.bert_head_variable_scope):
with tf.variable_scope('Optimizer'):
self.global_step = tf.get_variable('global_step',
shape=[],
dtype=tf.int32,
initializer=tf.constant_initializer(0),
trainable=False)
# default optimizer for Bert is Adam with fixed L2 regularization
if self.optimizer_params.get('optimizer') is None:
self.train_op = \
self.get_train_op(
self.loss,
body_learning_rate=self.shared_ph['learning_rate'],
optimizer=AdamWeightDecayOptimizer,
weight_decay_rate=self.optimizer_params.get('weight_decay_rate', 1e-6),
beta_1=0.9,
beta_2=0.999,
epsilon=1e-6,
optimizer_scope_name='Optimizer',
exclude_from_weight_decay=["LayerNorm",
"layer_norm",
"bias"])
else:
self.train_op = self.get_train_op(self.loss,
body_learning_rate=self.shared_ph['learning_rate'],
optimizer_scope_name='Optimizer')
if self.optimizer_params.get('optimizer') is None:
with tf.variable_scope('Optimizer'):
new_global_step = self.global_step + 1
self.train_op = tf.group(self.train_op, [self.global_step.assign(new_global_step)])
def _build_feed_dict(self, input_ids, input_masks, token_types, y=None, body_learning_rate=None):
sph = self.shared_ph
train = y is not None
feed_dict = {
sph['input_ids']: input_ids,
sph['input_masks']: input_masks,
sph['token_types']: token_types,
sph['is_train']: train,
}
if train:
feed_dict.update({
sph['learning_rate']: body_learning_rate,
self.y_ph: y,
sph['keep_prob']: self.keep_prob,
})
return feed_dict
[docs] def train_on_batch(self, *args, **kwargs) -> Dict[str, float]:
"""Trains the task on one batch. This method will work correctly if you override ``get_sess_run_train_args``
for your task.
Args:
kwargs: the keys are ``body_learning_rate`` and ``"in"`` and ``"in_y"`` params for the task.
Returns:
dictionary with calcutated task loss and body and head learning rates.
"""
fetches, feed_dict = self.get_sess_run_train_args(*args, **kwargs)
_, loss = self.sess.run(fetches, feed_dict=feed_dict)
return {f'{self.bert_head_variable_scope}_loss': loss,
f'{self.bert_head_variable_scope}_head_learning_rate':
float(kwargs['body_learning_rate']) * self.head_learning_rate_multiplier,
'bert_body_learning_rate': kwargs['body_learning_rate']}
[docs] @abstractmethod
def get_sess_run_infer_args(self, *args) -> Tuple[List[tf.Tensor], Dict[tf.placeholder, Any]]:
"""Returns fetches and feed_dict for inference. Fetches are lists of tensors and feed_dict is dictionary
with placeholder values required for fetches computation. The method is used inside ``MultiTaskBert``
``__call__`` method.
If ``self.return_probas`` is ``True`` fetches contains probabilities tensor and predictions tensor otherwise.
Overriding methods take task inputs as positional arguments.
ATTENTION! Let ``get_sess_run_infer_args`` method have ``n_x_args`` arguments. Then the order of first
``n_x_args`` arguments of ``get_sess_run_train_args`` method arguments has to match the order of
``get_sess_run_infer_args`` arguments.
Args:
args: task inputs.
Returns:
fetches and feed_dict
"""
pass
[docs] @abstractmethod
def get_sess_run_train_args(self, *args) -> Tuple[List[tf.Tensor], Dict[tf.placeholder, Any]]:
"""Returns fetches and feed_dict for task ``train_on_batch`` method.
Overriding methods take task inputs as positional arguments.
ATTENTION! Let ``get_sess_run_infer_args`` method have ``n_x_args`` arguments. Then the order of first
``n_x_args`` arguments of ``get_sess_run_train_args`` method arguments has to match the order of
``get_sess_run_infer_args`` arguments.
Args:
args: task inputs followed by expect outputs.
Returns:
fetches and feed_dict
"""
pass
[docs] @abstractmethod
def post_process_preds(self, sess_run_res: list) -> list:
"""Post process results of ``tf.Session.run`` called for task inference. Called from method
``MultiTaskBert.__call__``.
Args:
sess_run_res: computed fetches from ``get_sess_run_infer_args`` method
Returns:
post processed results
"""
pass
[docs]@register("mt_bert_seq_tagging_task")
class MTBertSequenceTaggingTask(MTBertTask):
"""BERT head for text tagging. It predicts a label for every token (not subtoken) in the text.
You can use it for sequence labelling tasks, such as morphological tagging or named entity recognition.
Objects of this class should be passed to the constructor of ``MultiTaskBert`` class in param ``tasks``.
Args:
n_tags: number of distinct tags
use_crf: whether to use CRF on top or not
use_birnn: whether to use bidirection rnn after BERT layers.
For NER and morphological tagging we usually set it to ``False`` as otherwise the model overfits
birnn_cell_type: the type of Bidirectional RNN. Either ``"lstm"`` or ``"gru"``
birnn_hidden_size: number of hidden units in the BiRNN layer in each direction
keep_prob: dropout keep_prob for non-Bert layers
encoder_dropout: dropout probability of encoder output layer
return_probas: set this to ``True`` if you need the probabilities instead of raw answers
encoder_layer_ids: list of averaged layers from Bert encoder (layer ids)
optimizer: name of ``tf.train.*`` optimizer or None for ``AdamWeightDecayOptimizer``
weight_decay_rate: L2 weight decay for ``AdamWeightDecayOptimizer``
learning_rate: learning rate of BERT head
"""
def __init__(
self,
n_tags: int = None,
use_crf: bool = None,
use_birnn: bool = False,
birnn_cell_type: str = 'lstm',
birnn_hidden_size: int = 128,
keep_prob: float = 1.,
encoder_dropout: float = 0.,
return_probas: bool = None,
encoder_layer_ids: List[int] = None,
learning_rate: float = 1e-3,
):
super().__init__(keep_prob, return_probas, learning_rate)
self.n_tags = n_tags
self.use_crf = use_crf
self.use_birnn = use_birnn
self.birnn_cell_type = birnn_cell_type
self.birnn_hidden_size = birnn_hidden_size
self.encoder_dropout = encoder_dropout
self.encoder_layer_ids = encoder_layer_ids
def _init_placeholders(self) -> None:
self.y_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='y_ph')
self.y_masks_ph = tf.placeholder(shape=(None, None),
dtype=tf.int32,
name='y_mask_ph')
self.encoder_keep_prob = tf.placeholder_with_default(1.0, shape=[], name='encoder_keep_prob_ph')
def _init_graph(self) -> None:
with tf.variable_scope(self.bert_head_variable_scope):
self._init_placeholders()
self.seq_lengths = tf.reduce_sum(self.y_masks_ph, axis=1)
layer_weights = tf.get_variable('layer_weights_',
shape=len(self.encoder_layer_ids),
initializer=tf.ones_initializer(),
trainable=True)
layer_mask = tf.ones_like(layer_weights)
layer_mask = tf.nn.dropout(layer_mask, self.encoder_keep_prob)
layer_weights *= layer_mask
# to prevent zero division
mask_sum = tf.maximum(tf.reduce_sum(layer_mask), 1.0)
layer_weights = tf.unstack(layer_weights / mask_sum)
# TODO: may be stack and reduce_sum is faster
units = sum(w * l for w, l in zip(layer_weights, self.encoder_layers()))
units = tf.nn.dropout(units, keep_prob=self.shared_ph['keep_prob'])
if self.use_birnn:
units, _ = bi_rnn(units,
self.birnn_hidden_size,
cell_type=self.birnn_cell_type,
seq_lengths=self.seq_lengths,
name='birnn')
units = tf.concat(units, -1)
# TODO: maybe add one more layer?
logits = tf.layers.dense(units, units=self.n_tags, name="output_dense")
self.logits = token_from_subtoken(logits, self.y_masks_ph)
# CRF
if self.use_crf:
transition_params = tf.get_variable('Transition_Params',
shape=[self.n_tags, self.n_tags],
initializer=tf.zeros_initializer())
log_likelihood, transition_params = \
tf.contrib.crf.crf_log_likelihood(self.logits,
self.y_ph,
self.seq_lengths,
transition_params)
loss_tensor = -log_likelihood
self._transition_params = transition_params
self.y_predictions = tf.argmax(self.logits, -1)
self.y_probas = tf.nn.softmax(self.logits, axis=2)
with tf.variable_scope("loss"):
tag_mask = self._get_tag_mask()
y_mask = tf.cast(tag_mask, tf.float32)
if self.use_crf:
self.loss = tf.reduce_mean(loss_tensor)
else:
self.loss = tf.losses.sparse_softmax_cross_entropy(labels=self.y_ph,
logits=self.logits,
weights=y_mask)
def _get_tag_mask(self) -> tf.Tensor:
"""
Returns: tag_mask,
a mask that selects positions corresponding to word tokens (not padding and ``CLS``)
"""
max_length = tf.reduce_max(self.seq_lengths)
one_hot_max_len = tf.one_hot(self.seq_lengths - 1, max_length)
tag_mask = tf.cumsum(one_hot_max_len[:, ::-1], axis=1)[:, ::-1]
return tag_mask
def encoder_layers(self):
"""
Returns: the output of BERT layers specified in ``self.encoder_layers_ids``
"""
return [self.bert.all_encoder_layers[i] for i in self.encoder_layer_ids]
@overrides
def _build_feed_dict(self, input_ids, input_masks, y_masks, y=None, body_learning_rate=None):
token_types = np.zeros(np.array(input_ids).shape)
sph = self.shared_ph
train = y is not None
feed_dict = super()._build_feed_dict(input_ids, input_masks, token_types, y, body_learning_rate)
if train:
feed_dict[self.encoder_keep_prob] = 1.0 - self.encoder_dropout
feed_dict[self.y_masks_ph] = y_masks
return feed_dict
def _decode_crf(self, feed_dict: Dict[tf.Tensor, np.ndarray]) -> List[np.ndarray]:
logits, trans_params, mask, seq_lengths = self.sess.run([self.logits,
self._transition_params,
self.y_masks_ph,
self.seq_lengths],
feed_dict=feed_dict)
# iterate over the sentences because no batching in viterbi_decode
y_pred = []
for logit, sequence_length in zip(logits, seq_lengths):
logit = logit[:int(sequence_length)] # keep only the valid steps
viterbi_seq, viterbi_score = tf.contrib.crf.viterbi_decode(logit, trans_params)
y_pred += [viterbi_seq]
return y_pred
[docs] def get_sess_run_infer_args(
self,
input_ids: Union[List[List[int]], np.ndarray],
input_masks: Union[List[List[int]], np.ndarray],
y_masks: Union[List[List[int]], np.ndarray],
) -> Tuple[List[tf.Tensor], Dict[tf.placeholder, Any]]:
"""Returns fetches and feed_dict for model inference. The method is called from ``MultiTaskBert.__call__``.
Args:
input_ids: indices of the subwords in vocabulary
input_masks: mask that determines where to attend and where not to
y_masks: mask which determines the first subword units in the the word
Returns:
list of fetches and feed_dict
"""
feed_dict = self._build_feed_dict(input_ids, input_masks, y_masks)
if self.return_probas:
fetches = self.y_probas
else:
if self.use_crf:
fetches = [self.logits, self._transition_params, self.y_masks_ph, self.seq_lengths]
else:
fetches = [self.y_predictions, self.seq_lengths]
return fetches, feed_dict
[docs] def get_sess_run_train_args(
self,
input_ids: Union[List[List[int]], np.ndarray],
input_masks: Union[List[List[int]], np.ndarray],
y_masks: Union[List[List[int]], np.ndarray],
y: Union[List[List[int]], np.ndarray],
body_learning_rate: float) -> Tuple[List[tf.Tensor], Dict[tf.placeholder, Any]]:
"""Returns fetches and feed_dict for model ``train_on_batch`` method.
Args:
input_ids: indices of the subwords in vocabulary
input_masks: mask that determines where to attend and where not to
y_masks: mask which determines the first subword units in the the word
y: indices of ground truth tags
body_learning_rate: learning rate for BERT body
Returns:
list of fetches and feed_dict
"""
feed_dict = self._build_feed_dict(input_ids, input_masks, y_masks, y=y, body_learning_rate=body_learning_rate)
fetches = [self.train_op, self.loss]
return fetches, feed_dict
[docs] def post_process_preds(self, sess_run_res: List[np.ndarray]) -> Union[List[List[int]], List[np.ndarray]]:
"""Decodes CRF if needed and returns predictions or probabilities.
Args:
sess_run_res: list of computed fetches gathered by ``get_sess_run_infer_args``
Returns:
predictions or probabilities depending on ``return_probas`` attribute
"""
if self.return_probas:
pred = sess_run_res
else:
if self.use_crf:
logits, trans_params, mask, seq_lengths = sess_run_res
pred = []
for logit, sequence_length in zip(logits, seq_lengths):
logit = logit[:int(sequence_length)] # keep only the valid steps
viterbi_seq, viterbi_score = tf.contrib.crf.viterbi_decode(logit, trans_params)
pred += [viterbi_seq]
else:
pred, seq_lengths = sess_run_res
pred = [p[:l] for l, p in zip(seq_lengths, pred)]
return pred
[docs]@register("mt_bert_classification_task")
class MTBertClassificationTask(MTBertTask):
"""Task for text classification.
It uses output from [CLS] token and predicts labels using linear transformation.
Args:
n_classes: number of classes
return_probas: set ``True`` if return class probabilities instead of most probable label needed
one_hot_labels: set ``True`` if one-hot encoding for labels is used
keep_prob: dropout keep_prob for non-BERT layers
multilabel: set ``True`` if it is multi-label classification
learning_rate: learning rate of BERT head
optimizer: name of ``tf.train.*`` optimizer or ``None`` for ``AdamWeightDecayOptimizer``
"""
def __init__(
self,
n_classes: int = None,
return_probas: bool = None,
one_hot_labels: bool = None,
keep_prob: float = 1.,
multilabel: bool = False,
learning_rate: float = 2e-5,
optimizer: str = "Adam",
):
super().__init__(keep_prob, return_probas, learning_rate)
self.n_classes = n_classes
self.one_hot_labels = one_hot_labels
self.multilabel = multilabel
if self.multilabel and not self.one_hot_labels:
raise RuntimeError('Use one-hot encoded labels for multilabel classification!')
if self.multilabel and not self.return_probas:
raise RuntimeError('Set return_probas to True for multilabel classification!')
def _init_placeholders(self):
if not self.one_hot_labels:
self.y_ph = tf.placeholder(shape=(None,), dtype=tf.int32, name='y_ph')
else:
self.y_ph = tf.placeholder(shape=(None, self.n_classes), dtype=tf.float32, name='y_ph')
def _init_graph(self):
with tf.variable_scope(self.bert_head_variable_scope):
self._init_placeholders()
output_layer = self.bert.get_pooled_output()
hidden_size = output_layer.shape[-1].value
output_weights = tf.get_variable(
"output_weights", [self.n_classes, hidden_size],
initializer=tf.truncated_normal_initializer(stddev=0.02))
output_bias = tf.get_variable(
"output_bias", [self.n_classes], initializer=tf.zeros_initializer())
with tf.variable_scope("loss"):
output_layer = tf.nn.dropout(output_layer, keep_prob=self.shared_ph['keep_prob'])
logits = tf.matmul(output_layer, output_weights, transpose_b=True)
logits = tf.nn.bias_add(logits, output_bias)
if self.one_hot_labels:
one_hot_labels = self.y_ph
else:
one_hot_labels = tf.one_hot(self.y_ph, depth=self.n_classes, dtype=tf.float32)
self.y_predictions = tf.argmax(logits, axis=-1)
if not self.multilabel:
log_probs = tf.nn.log_softmax(logits, axis=-1)
self.y_probas = tf.nn.softmax(logits, axis=-1)
per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
self.loss = tf.reduce_mean(per_example_loss)
else:
self.y_probas = tf.nn.sigmoid(logits)
self.loss = tf.reduce_mean(
tf.nn.sigmoid_cross_entropy_with_logits(labels=one_hot_labels, logits=logits))
[docs] def get_sess_run_train_args(
self,
features: List[InputFeatures],
y: Union[List[int], List[List[int]]],
body_learning_rate: float) -> Tuple[List[tf.Tensor], Dict[tf.placeholder, Any]]:
"""Returns fetches and feed_dict for model ``train_on_batch`` method.
Args:
features: text features created by BERT preprocessor.
y: batch of labels (class id or one-hot encoding)
body_learning_rate: learning rate for BERT body
Returns:
list of fetches and feed_dict
"""
input_ids = [f.input_ids for f in features]
input_masks = [f.input_mask for f in features]
input_type_ids = [f.input_type_ids for f in features]
feed_dict = self._build_feed_dict(input_ids, input_masks, input_type_ids, y=y,
body_learning_rate=body_learning_rate)
fetches = [self.train_op, self.loss]
return fetches, feed_dict
[docs] def get_sess_run_infer_args(
self,
features: List[InputFeatures]) -> Tuple[List[tf.Tensor], Dict[tf.placeholder, Any]]:
"""Returns fetches and feed_dict for model inference. The method is called from ``MultiTaskBert.__call__``.
Args:
features: text features created by BERT preprocessor.
Returns:
list of fetches and feed_dict
"""
input_ids = [f.input_ids for f in features]
input_masks = [f.input_mask for f in features]
input_type_ids = [f.input_type_ids for f in features]
feed_dict = self._build_feed_dict(input_ids, input_masks, input_type_ids)
fetches = self.y_probas if self.return_probas else self.y_predictions
return fetches, feed_dict
[docs] def post_process_preds(self, sess_run_res):
"""Returns ``tf.Session.run`` results for inference without changes."""
return sess_run_res
[docs]@register('mt_bert')
class MultiTaskBert(LRScheduledTFModel):
"""The component for multi-task BERT. It builds the BERT body, launches building of BERT heads.
The component aggregates components implementing BERT heads. The head components are called tasks.
``__call__`` and ``train_on_batch`` methods of ``MultiTaskBert`` are used for inference and training of
BERT heads. BERT head components, which are derived from ``MTBertTask``, can be used only inside this class.
One training iteration consists of one ``train_on_batch`` call for every task.
If ``inference_task_names`` is not ``None``, then the component is created for training. Otherwise, the
component is created for inference. If component is created for inference, several tasks can be run
simultaneously. For explanation see parameter ``inference_task_names`` description.
Args:
tasks: a dictionary. Task names are dictionary keys and objects of ``MTBertTask`` subclasses are dictionary
values. Task names are used as variable scopes in computational graph so it is important to use same
names in multi-task BERT train and inference configuration files.
bert_config_file: path to BERT configuration file
pretrained_bert: pre-trained BERT checkpoint
attention_probs_keep_prob: keep_prob for BERT self-attention layers
hidden_keep_prob: keep_prob for BERT hidden layers
body_learning_rate: learning rate of BERT body
min_body_learning_rate: min value of body learning rate if learning rate decay is used
learning_rate_drop_patience: how many validations with no improvements to wait
learning_rate_drop_div: the divider of the learning rate after ``learning_rate_drop_patience`` unsuccessful
validations
load_before_drop: whether to load best model before dropping learning rate or not
clip_norm: clip gradients by norm
freeze_embeddings: set to False to train input embeddings
inference_task_names: names of tasks on which inference is done.
If this parameter is provided, the component is created for inference, else the component is created for
training.
If ``inference_task_names`` is a string, then it is a name of the task called separately from other tasks
(in individual ``tf.Session.run`` call).
If ``inference_task_names`` is a ``list``, then elements of this list are either strings or lists of
strings. You can combine these options. For example, ``["task_name1", ["task_name2", "task_name3"],
["task_name4", "task_name5"]]``.
If an element of ``inference_task_names`` list is a string, the element is a name of the task that is
computed when ``__call__`` method is called.
If an element of the ``inference_task_names`` parameter is a list of strings
``["task_name1", "task_name2", ...]``, then tasks ``"task_name1"``, ``"task_name2"`` and so on are run
simultaneously in ``tf.Session.run`` call. This option is available if tasks ``"task_name1"``,
``"task_name2"`` and so on have common inputs. Despite the fact that tasks share inputs, if positional
arguments are used in methods ``__call__`` and ``train_on_batch``, all arguments are passed individually.
For instance, if ``"task_name1"``, ``"task_name2"``, and ``"task_name3"`` all take an argument with name
``x`` in the model pipe, then the ``__call__`` method takes arguments ``(x, x, x)``.
in_distribution: The distribution of variables listed in the ``"in"`` config parameter between tasks.
``in_distribution`` can be ``None`` if only 1 task is called. In that case all variables
listed in ``"in"`` are arguments of 1 task.
``in_distribution`` can be a dictionary of ``int``. If that is the case, then keys of ``in_distribution``
are task names and values are numbers of variables from ``"in"`` parameter of config which are inputs of
corresponding task. The variables in ``"in"`` parameter have to be in the same order the tasks are listed
in ``in_distribution``.
``in_distribution`` can be a dictionary of lists of ``str``. Strings are names of variables from ``"in"``
configuration parameter. If ``"in"`` parameter is a list, then ``in_distribution`` works the same way as
when ``in_distribution`` is dictionary of ``int``. Values of ``in_distribution``, which are lists, are
replaced by their lengths. If ``"in"`` parameter in component config is a dictionary, then the order of
strings in ``in_distribution`` values has to match the order of arguments of ``train_on_batch`` and
``get_sess_run_infer_args`` methods of task components.
in_y_distribution: The same as ``in_distribution`` for ``"in_y"`` config parameter.
"""
def __init__(self,
tasks: Dict[str, MTBertTask],
bert_config_file: str,
pretrained_bert: str = None,
attention_probs_keep_prob: float = None,
hidden_keep_prob: float = None,
optimizer: str = None,
weight_decay_rate: float = 1e-6,
body_learning_rate: float = 1e-3,
min_body_learning_rate: float = 1e-7,
learning_rate_drop_patience: int = 20,
learning_rate_drop_div: float = 2.0,
load_before_drop: bool = True,
clip_norm: float = 1.0,
freeze_embeddings: bool = True,
inference_task_names: Optional[Union[str, List[Union[List[str], str]]]] = None,
in_distribution: Optional[Dict[str, Union[int, List[str]]]] = None,
in_y_distribution: Optional[Dict[str, Union[int, List[str]]]] = None,
**kwargs) -> None:
super().__init__(learning_rate=body_learning_rate,
learning_rate_drop_div=learning_rate_drop_div,
learning_rate_drop_patience=learning_rate_drop_patience,
load_before_drop=load_before_drop,
clip_norm=clip_norm,
**kwargs)
self.optimizer_params = {
"optimizer": optimizer,
"body_learning_rate": body_learning_rate,
"min_body_learning_rate": min_body_learning_rate,
"weight_decay_rate": weight_decay_rate
}
self.freeze_embeddings = freeze_embeddings
self.tasks = tasks
if inference_task_names is not None and isinstance(inference_task_names, str):
inference_task_names = [inference_task_names]
self.inference_task_names = inference_task_names
self.mode = 'train' if self.inference_task_names is None else 'inference'
self.shared_ph = None
self.bert_config = BertConfig.from_json_file(str(expand_path(bert_config_file)))
if attention_probs_keep_prob is not None:
self.bert_config.attention_probs_dropout_prob = 1.0 -attention_probs_keep_prob
if hidden_keep_prob is not None:
self.bert_config.hidden_dropout_prob = 1.0 - hidden_keep_prob
self.sess_config = tf.ConfigProto(allow_soft_placement=True)
self.sess_config.gpu_options.allow_growth = True
self.sess = tf.Session(config=self.sess_config)
self._init_bert_body_graph()
self.build_tasks()
self.sess.run(tf.global_variables_initializer())
if pretrained_bert is not None:
pretrained_bert = str(expand_path(pretrained_bert))
if tf.train.checkpoint_exists(pretrained_bert) \
and not (self.load_path and tf.train.checkpoint_exists(str(self.load_path.resolve()))) \
and self.mode == 'train':
log.info('[initializing model with Bert from {}]'.format(pretrained_bert))
var_list = self._get_saveable_variables(
exclude_scopes=('Optimizer', 'learning_rate', 'momentum') + tuple(self.tasks.keys()))
saver = tf.train.Saver(var_list)
saver.restore(self.sess, pretrained_bert)
if self.load_path is not None:
self.load()
self.in_distribution = in_distribution
self.in_y_distribution = in_y_distribution
def build_tasks(self):
def get_train_op(*args, **kwargs):
return self.get_train_op(*args, **kwargs)
for task_name, task_obj in self.tasks.items():
task_obj.build(
bert_body=self.bert,
optimizer_params=self.optimizer_params,
shared_placeholders=self.shared_ph,
sess=self.sess,
mode=self.mode,
get_train_op_func=get_train_op,
freeze_embeddings=self.freeze_embeddings,
bert_head_variable_scope=task_name
)
def _init_shared_placeholders(self) -> None:
self.shared_ph = {
'input_ids': tf.placeholder(shape=(None, None),
dtype=tf.int32,
name='token_indices_ph'),
'input_masks': tf.placeholder(shape=(None, None),
dtype=tf.int32,
name='token_mask_ph'),
'learning_rate': tf.placeholder_with_default(0.0, shape=[], name='learning_rate_ph'),
'keep_prob': tf.placeholder_with_default(1.0, shape=[], name='keep_prob_ph'),
'is_train': tf.placeholder_with_default(False, shape=[], name='is_train_ph')}
self.shared_ph['token_types'] = tf.placeholder_with_default(
tf.zeros_like(self.shared_ph['input_ids'], dtype=tf.int32),
shape=self.shared_ph['input_ids'].shape,
name='token_types_ph')
def _init_bert_body_graph(self) -> None:
self._init_shared_placeholders()
sph = self.shared_ph
self.bert = BertModel(config=self.bert_config,
is_training=sph['is_train'],
input_ids=sph['input_ids'],
input_mask=sph['input_masks'],
token_type_ids=sph['token_types'],
use_one_hot_embeddings=False)
def save(self, exclude_scopes=('Optimizer', 'learning_rate', 'momentum')) -> None:
return super().save(exclude_scopes=exclude_scopes)
def load(self,
exclude_scopes=('Optimizer',
'learning_rate',
'momentum'),
**kwargs) -> None:
return super().load(exclude_scopes=exclude_scopes, **kwargs)
[docs] def train_on_batch(self, *args, **kwargs) -> Dict[str, Dict[str, float]]:
"""Calls ``train_on_batch`` methods for every task. This method takes ``args`` or ``kwargs`` but not both.
The order of ``args`` is the same as the order of tasks in the component parameters:
.. highlight:: python
.. code-block:: python
args = [
task1_in_x[0],
task1_in_x[1],
task1_in_x[2],
...
task1_in_y[0],
task1_in_y[1],
...
task2_in_x[0],
...
]
If ``kwargs`` are used and ``in_distribution`` and ``in_y_distribution`` attributes are dictionaries of lists
of strings, then keys of ``kwargs`` have to be same as strings in ``in_distribution`` and
``in_y_distribution``. If ``in_distribution`` and ``in_y_distribution`` are dictionaries of ``int``, then
``kwargs`` values are treated the same way as ``args``.
Args:
args: task inputs and expected outputs
kwargs: task inputs and expected outputs
Returns:
dictionary of dictionaries with task losses and learning rates.
"""
# TODO: test passing arguments as args
if args and kwargs:
raise ValueError("You can use either args or kwargs not both")
n_in = sum([len(inp) if isinstance(inp, list) else inp for inp in self.in_distribution.values()])
if args:
args_in, args_in_y = args[:n_in], args[n_in:]
in_by_tasks = self._distribute_arguments_by_tasks(args_in, {}, list(self.tasks.keys()), "in")
in_y_by_tasks = self._distribute_arguments_by_tasks(args_in_y, {}, list(self.tasks.keys()), "in_y")
else:
kwargs_in, kwargs_in_y = {}, {}
for i, (k, v) in enumerate(kwargs.items()):
if i < n_in:
kwargs_in[k] = v
else:
kwargs_in_y[k] = v
in_by_tasks = self._distribute_arguments_by_tasks({}, kwargs_in, list(self.tasks.keys()), "in")
in_y_by_tasks = self._distribute_arguments_by_tasks({}, kwargs_in_y, list(self.tasks.keys()), "in_y")
train_on_batch_results = {}
for task_name, task in self.tasks.items():
train_on_batch_results.update(
task.train_on_batch(
*in_by_tasks[task_name],
*in_y_by_tasks[task_name],
body_learning_rate=max(self.get_learning_rate(), self.optimizer_params['min_body_learning_rate'])
)
)
for k, v in train_on_batch_results.items():
train_on_batch_results[k] = float(f"{v:.3}")
return train_on_batch_results
@staticmethod
def _unite_task_feed_dicts(d1, d2, task_name):
d = copy.copy(d1)
for k, v in d2.items():
if k in d:
comp = v != d[k]
if isinstance(comp, np.ndarray):
comp = comp.any()
if comp:
raise ValueError(
f"Value of placeholder '{k}' for task '{task_name}' does not match value of this placeholder "
"in other tasks")
else:
d[k] = v
return d
def _distribute_arguments_by_tasks(self, args, kwargs, task_names, what_to_distribute, in_distribution=None):
if args and kwargs:
raise ValueError("You may use args or kwargs but not both")
if what_to_distribute == "in":
if in_distribution is not None:
distribution = in_distribution
else:
distribution = self.in_distribution
elif what_to_distribute == "in_y":
if in_distribution is not None:
raise ValueError(
f"If parameter `what_to_distribute` is 'in_y', parameter `in_distribution` has to be `None`. "
f"in_distribution = {in_distribution}")
distribution = self.in_y_distribution
else:
raise ValueError(f"`what_to_distribute` can be 'in' or 'in_y', {repr(what_to_distribute)} is given")
if distribution is None:
if len(task_names) != 1:
raise ValueError(f"If no `{what_to_distribute}_distribution` is not provided there have to be only 1"
"task for inference")
return {task_names[0]: list(kwargs.values()) if kwargs else list(args)}
if all([isinstance(task_distr, int) for task_distr in distribution.values()]):
ints = True
elif all([isinstance(task_distr, list) for task_distr in distribution.values()]):
ints = False
else:
raise ConfigError(
f"Values of `{what_to_distribute}_distribution` attribute of `MultiTaskBert` have to be "
f"either `int` or `list` not both. "
f"{what_to_distribute}_distribution = {distribution}")
args_by_task = {}
flattened = []
for task_name in task_names:
if isinstance(task_name, str):
flattened.append(task_name)
else:
flattened.extend(task_name)
task_names = flattened
if args and not ints:
ints = True
distribution = {task_name: len(in_distr) for task_name, in_distr in distribution.items()}
if ints:
if kwargs:
values = list(kwargs.values())
else:
values = args
n_distributed = sum([n_args for n_args in distribution.values()])
if len(values) != n_distributed:
raise ConfigError(
f"The number of '{what_to_distribute}' arguments of MultitaskBert does not match "
f"the number of distributed params according to '{what_to_distribute}_distribution' parameter. "
f"{len(values)} parameters are in '{what_to_distribute}' and {n_distributed} parameters are "
f"required '{what_to_distribute}_distribution'. "
f"{what_to_distribute}_distribution = {distribution}")
values_taken = 0
for task_name in task_names:
args_by_task[task_name] = {}
n_args = distribution[task_name]
args_by_task[task_name] = [values[i] for i in range(values_taken, values_taken + n_args)]
values_taken += n_args
else:
assert kwargs
arg_names_used = []
for task_name in task_names:
in_distr = distribution[task_name]
args_by_task[task_name] = {}
args_by_task[task_name] = [kwargs[arg_name] for arg_name in in_distr]
arg_names_used += in_distr
set_used = set(arg_names_used)
set_all = set(kwargs.keys())
if set_used != set_all:
raise ConfigError(f"There are unused '{what_to_distribute}' parameters {set_all - set_used}")
return args_by_task
[docs] def __call__(self, *args, **kwargs):
"""Calls one or several BERT heads depending on provided task names. ``args`` and ``kwargs`` contain
inputs of BERT tasks. ``args`` and ``kwargs`` cannot be used together. If ``args`` are used ``args`` content
has to be
.. code-block:: python
args = [
task1_in_x[0],
task1_in_x[1],
...
task2_in_x[0],
task2_in_x[1],
...
]
If ``kwargs`` are used and ``in_distribution`` is a dictionary of ``int``, then ``kwargs``' order has to be
the same as ``args`` order described in the previous paragraph. If ``in_distribution`` is a dictionary of
lists of ``str``, then all task names from ``in_distribution`` have to be present in ``kwargs`` keys.
Returns:
list of results of called tasks.
"""
if self.inference_task_names is None:
task_names = list(self.tasks.keys())
else:
task_names = self.inference_task_names
if not task_names:
raise ValueError("No tasks to call")
if args and kwargs:
raise ValueError("You may use either args or kwargs not both")
return self.call(args, kwargs, task_names)
[docs] def call(
self,
args: Tuple[Any],
kwargs: Dict[str, Any],
task_names: Optional[Union[List[str], str]],
in_distribution: Optional[Union[Dict[str, int], Dict[str, List[str]]]] = None,
):
"""Calls one or several BERT heads depending on provided task names in ``task_names`` parameter. ``args`` and
``kwargs`` contain inputs of BERT tasks. ``args`` and ``kwargs cannot be used simultaneously. If ``args`` are
used ``args``, content has to be
.. code-block:: python
args = [
task1_in_x[0],
task1_in_x[1],
...
task2_in_x[0],
task2_in_x[1],
...
]
If ``kwargs`` is used ``kwargs`` keys has to match content of ``in_names`` params of called tasks.
Args:
args: generally, ``args`` parameter of ``__call__`` method of this component or ``MTBertReUser``. Inputs
of one or several tasks. Has to be empty if ``kwargs`` argument is used.
kwargs: generally, ``kwargs`` parameter of ``__call__`` method of this component or ``MTBertReUser``.
Inputs of one or several tasks. Has to be empty if ``args`` argument is used.
task_names: names of tasks that are called. If ``str``, then 1 task is called. If a task name is an
element of ``task_names`` list, then this task is run independently. If task an element of
``task_names`` is an list of strings, then tasks in the inner list are run simultaneously.
in_distribution: a distribution of variables from ``"in"`` config parameters between tasks. For details
see method ``__init__`` docstring.
Returns:
list results of called tasks.
"""
args_by_task = self._distribute_arguments_by_tasks(args, kwargs, task_names, "in", in_distribution)
results = []
task_count = 0
for elem in task_names:
if isinstance(elem, str):
task_count += 1
task = self.tasks[elem]
fetches, feed_dict = task.get_sess_run_infer_args(*args_by_task[elem])
sess_run_res = self.sess.run(fetches, feed_dict=feed_dict)
results.append(task.post_process_preds(sess_run_res))
else:
fetches = []
for task_name in elem:
task_count += 1
feed_dict = {}
task_fetches, task_feed_dict = self.tasks[task_name].get_sess_run_infer_args(
*args_by_task[task_name])
fetches.append(task_fetches)
feed_dict = self._unite_task_feed_dicts(feed_dict, task_feed_dict, task_name)
sess_run_res = self.sess.run(fetches, feed_dict=feed_dict)
for task_name, srs in zip(elem, sess_run_res):
task_results = self.tasks[task_name].post_process_preds(srs)
results.append(task_results)
if task_count == 1:
results = results[0]
return results
[docs]@register("mt_bert_reuser")
class MTBertReUser:
"""Instances of this class are for multi-task BERT inference. In inference config ``MultiTaskBert`` class may
not perform inference of some tasks. For example, you may need to sequentially apply two models with BERT.
In that case, ``mt_bert_reuser`` is created to call remaining tasks.
Args:
mt_bert: An instance of ``MultiTaskBert``
task_names: Names of infered tasks. If ``task_names`` is ``str``, then ``task_names`` is the name of the only
infered task. If ``task_names`` is ``list``, then its elements can be either strings or lists of strings.
If an element of ``task_names`` is a string, then this element is a name of a task that is run
independently. If an element of ``task_names`` is a list of strings, then the element is a list of names
of tasks that have common inputs and run simultaneously. For detailed information look up
``MultiTaskBert`` ``inference_task_names`` parameter.
"""
def __init__(
self,
mt_bert: MultiTaskBert,
task_names: Union[str, List[Union[str, List[str]]]],
in_distribution: Union[Dict[str, int], Dict[str, List[str]]] = None,
*args,
**kwargs):
self.mt_bert = mt_bert
if isinstance(task_names, str):
task_names = [task_names]
elif not task_names:
raise ValueError("At least 1 task has to specified")
self.task_names = task_names
flattened = []
for elem in self.task_names:
if isinstance(elem, str):
flattened.append(elem)
else:
flattened.extend(elem)
if in_distribution is None:
if len(flattened) > 1:
raise ValueError(
"If ``in_distribution`` parameter is not provided, there has to be only 1 task."
f"task_names = {self.task_names}")
self.in_distribution = in_distribution
[docs] def __call__(self, *args, **kwargs) -> List[Any]:
"""Infer tasks listed in parameter ``task_names``. One of parameters ``args`` and ``kwargs`` has to be empty.
Args:
args: inputs and labels of infered tasks.
kwargs: inputs and labels of infered tasks.
Returns:
list of results of inference of tasks listed in ``task_names``
"""
res = self.mt_bert.call(args, kwargs, task_names=self.task_names, in_distribution=self.in_distribution)
return res