# Copyright 2017 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import getLogger
from typing import List, Dict, Union
import tensorflow as tf
from bert_dp.modeling import BertConfig, BertModel
from bert_dp.optimization import AdamWeightDecayOptimizer
from bert_dp.preprocessing import InputFeatures
from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.registry import register
from deeppavlov.core.models.tf_model import LRScheduledTFModel
# Module-level logger, named after this module via ``__name__``.
logger = getLogger(__name__)
@register('bert_classifier')
class BertClassifierModel(LRScheduledTFModel):
    """Bert-based model for text classification.

    It uses the pooled output of Bert (``BertModel.get_pooled_output()``) and predicts
    labels using a single linear transformation on top of it.

    Args:
        bert_config_file: path to Bert configuration file
        n_classes: number of classes
        keep_prob: dropout keep_prob for non-Bert layers
        one_hot_labels: set True if one-hot encoding for labels is used
        multilabel: set True if it is multi-label classification
        return_probas: set True if class probabilities should be returned instead of
            the most probable label
        attention_probs_keep_prob: keep_prob for Bert self-attention layers
        hidden_keep_prob: keep_prob for Bert hidden layers
        optimizer: name of tf.train.* optimizer or None for `AdamWeightDecayOptimizer`
        num_warmup_steps: number of warmup steps; the value is stored on the instance
            but is not used by the code of this class (warmup is still a TODO)
        weight_decay_rate: L2 weight decay for `AdamWeightDecayOptimizer`
        pretrained_bert: pretrained Bert checkpoint
        min_learning_rate: min value of learning rate if learning rate decay is used
        **kwargs: forwarded unchanged to the parent class constructor

    Raises:
        RuntimeError: if ``multilabel`` is set without ``one_hot_labels`` or without
            ``return_probas``.
    """

    # TODO: add warmup
    # TODO: add head-only pre-training
    def __init__(self, bert_config_file, n_classes: int, keep_prob: float,
                 one_hot_labels: bool = False, multilabel: bool = False, return_probas: bool = False,
                 attention_probs_keep_prob=None, hidden_keep_prob=None,
                 optimizer=None, num_warmup_steps=None, weight_decay_rate: float = 0.01,
                 pretrained_bert=None, min_learning_rate: float = 1e-06, **kwargs) -> None:
        super().__init__(**kwargs)

        self.return_probas = return_probas
        self.n_classes = n_classes
        self.min_learning_rate = min_learning_rate
        self.keep_prob = keep_prob
        self.one_hot_labels = one_hot_labels
        self.multilabel = multilabel
        self.optimizer = optimizer
        self.num_warmup_steps = num_warmup_steps
        self.weight_decay_rate = weight_decay_rate

        # Multi-label mode needs a full label vector per sample and has no single "best" class.
        if self.multilabel and not self.one_hot_labels:
            raise RuntimeError('Use one-hot encoded labels for multilabel classification!')

        if self.multilabel and not self.return_probas:
            raise RuntimeError('Set return_probas to True for multilabel classification!')

        self.bert_config = BertConfig.from_json_file(str(expand_path(bert_config_file)))

        # The Bert config stores dropout probabilities, the arguments are keep probabilities.
        if attention_probs_keep_prob is not None:
            self.bert_config.attention_probs_dropout_prob = 1.0 - attention_probs_keep_prob
        if hidden_keep_prob is not None:
            self.bert_config.hidden_dropout_prob = 1.0 - hidden_keep_prob

        self.sess_config = tf.ConfigProto(allow_soft_placement=True)
        self.sess_config.gpu_options.allow_growth = True
        self.sess = tf.Session(config=self.sess_config)

        self._init_graph()
        self._init_optimizer()

        self.sess.run(tf.global_variables_initializer())

        if pretrained_bert is not None:
            pretrained_bert = str(expand_path(pretrained_bert))

            # Initialize from the pretrained Bert checkpoint only when there is no
            # checkpoint at `load_path` to restore from.
            if tf.train.checkpoint_exists(pretrained_bert) \
                    and not (self.load_path and tf.train.checkpoint_exists(str(self.load_path.resolve()))):
                logger.info('[initializing model with Bert from {}]'.format(pretrained_bert))
                # Exclude optimizer and classification variables from saved variables
                var_list = self._get_saveable_variables(
                    exclude_scopes=('Optimizer', 'learning_rate', 'momentum', 'output_weights', 'output_bias'))
                saver = tf.train.Saver(var_list)
                saver.restore(self.sess, pretrained_bert)

        if self.load_path is not None:
            self.load()

    def _init_graph(self):
        """Build the Bert encoder, the linear classification head and the loss.

        Sets ``self.bert``, ``self.y_predictions``, ``self.y_probas`` and ``self.loss``.
        """
        self._init_placeholders()

        self.bert = BertModel(config=self.bert_config,
                              is_training=self.is_train_ph,
                              input_ids=self.input_ids_ph,
                              input_mask=self.input_masks_ph,
                              token_type_ids=self.token_types_ph,
                              use_one_hot_embeddings=False,
                              )

        output_layer = self.bert.get_pooled_output()
        hidden_size = output_layer.shape[-1].value

        # Parameters of the linear classification head on top of the pooled output.
        output_weights = tf.get_variable(
            "output_weights", [self.n_classes, hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))

        output_bias = tf.get_variable(
            "output_bias", [self.n_classes], initializer=tf.zeros_initializer())

        with tf.variable_scope("loss"):
            output_layer = tf.nn.dropout(output_layer, keep_prob=self.keep_prob_ph)
            logits = tf.matmul(output_layer, output_weights, transpose_b=True)
            logits = tf.nn.bias_add(logits, output_bias)

            if self.one_hot_labels:
                one_hot_labels = self.y_ph
            else:
                one_hot_labels = tf.one_hot(self.y_ph, depth=self.n_classes, dtype=tf.float32)

            self.y_predictions = tf.argmax(logits, axis=-1)

            if not self.multilabel:
                # Single-label case: softmax probabilities and cross-entropy loss.
                log_probs = tf.nn.log_softmax(logits, axis=-1)
                self.y_probas = tf.nn.softmax(logits, axis=-1)
                per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
                self.loss = tf.reduce_mean(per_example_loss)
            else:
                # Multi-label case: an independent sigmoid per class.
                self.y_probas = tf.nn.sigmoid(logits)
                self.loss = tf.reduce_mean(
                    tf.nn.sigmoid_cross_entropy_with_logits(labels=one_hot_labels, logits=logits))

    def _init_placeholders(self):
        """Create input, label and training-control placeholders."""
        self.input_ids_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='ids_ph')
        self.input_masks_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='masks_ph')
        self.token_types_ph = tf.placeholder(shape=(None, None), dtype=tf.int32, name='token_types_ph')

        # Labels are either class ids (one per sample) or one-hot / multi-hot float vectors.
        if not self.one_hot_labels:
            self.y_ph = tf.placeholder(shape=(None,), dtype=tf.int32, name='y_ph')
        else:
            self.y_ph = tf.placeholder(shape=(None, self.n_classes), dtype=tf.float32, name='y_ph')

        # Defaults correspond to inference: zero learning rate, no dropout, is_train False.
        self.learning_rate_ph = tf.placeholder_with_default(0.0, shape=[], name='learning_rate_ph')
        self.keep_prob_ph = tf.placeholder_with_default(1.0, shape=[], name='keep_prob_ph')
        self.is_train_ph = tf.placeholder_with_default(False, shape=[], name='is_train_ph')

    def _init_optimizer(self):
        """Create ``self.global_step`` and ``self.train_op`` inside the 'Optimizer' scope."""
        with tf.variable_scope('Optimizer'):
            self.global_step = tf.get_variable('global_step', shape=[], dtype=tf.int32,
                                               initializer=tf.constant_initializer(0), trainable=False)
            # default optimizer for Bert is Adam with fixed L2 regularization
            if self.optimizer is None:
                self.train_op = self.get_train_op(self.loss, learning_rate=self.learning_rate_ph,
                                                  optimizer=AdamWeightDecayOptimizer,
                                                  weight_decay_rate=self.weight_decay_rate,
                                                  beta_1=0.9,
                                                  beta_2=0.999,
                                                  epsilon=1e-6,
                                                  exclude_from_weight_decay=["LayerNorm", "layer_norm", "bias"]
                                                  )
                # In this branch the global step is incremented explicitly alongside the train op.
                new_global_step = self.global_step + 1
                self.train_op = tf.group(self.train_op, [self.global_step.assign(new_global_step)])
            else:
                # NOTE(review): `self.optimizer` is not forwarded here, and `optimizer` is consumed
                # by this class' __init__ rather than passed to super().__init__; presumably the
                # parent falls back to its own default optimizer - confirm that a user-specified
                # optimizer name is actually honored.
                self.train_op = self.get_train_op(self.loss, learning_rate=self.learning_rate_ph)

    def _build_feed_dict(self, input_ids, input_masks, token_types, y=None):
        """Build the feed dict for a session run.

        Args:
            input_ids: batch of token ids
            input_masks: batch of input masks
            token_types: batch of token type ids
            y: batch of labels; when given, the training-time values are fed as well
                (labels, learning rate, dropout keep_prob and ``is_train=True``)

        Returns:
            dict mapping placeholders to values
        """
        feed_dict = {
            self.input_ids_ph: input_ids,
            self.input_masks_ph: input_masks,
            self.token_types_ph: token_types,
        }
        if y is not None:
            feed_dict.update({
                self.y_ph: y,
                # The scheduled learning rate is never allowed to drop below min_learning_rate.
                self.learning_rate_ph: max(self.get_learning_rate(), self.min_learning_rate),
                self.keep_prob_ph: self.keep_prob,
                self.is_train_ph: True,
            })

        return feed_dict

    @staticmethod
    def _split_features(features):
        """Split a batch of InputFeatures into lists of input ids, input masks and type ids."""
        input_ids = [f.input_ids for f in features]
        input_masks = [f.input_mask for f in features]
        input_type_ids = [f.input_type_ids for f in features]
        return input_ids, input_masks, input_type_ids

    def train_on_batch(self, features: List[InputFeatures], y: Union[List[int], List[List[int]]]) -> Dict:
        """Train model on given batch.

        This method calls train_op using features and y (labels).

        Args:
            features: batch of InputFeatures
            y: batch of labels (class id or one-hot encoding)

        Returns:
            dict with loss and learning_rate values
        """
        input_ids, input_masks, input_type_ids = self._split_features(features)

        feed_dict = self._build_feed_dict(input_ids, input_masks, input_type_ids, y)

        _, loss = self.sess.run([self.train_op, self.loss], feed_dict=feed_dict)
        return {'loss': loss, 'learning_rate': feed_dict[self.learning_rate_ph]}

    def __call__(self, features: List[InputFeatures]) -> Union[List[int], List[List[float]]]:
        """Make prediction for given features (texts).

        Args:
            features: batch of InputFeatures

        Returns:
            predicted classes or probabilities of each class
        """
        input_ids, input_masks, input_type_ids = self._split_features(features)

        feed_dict = self._build_feed_dict(input_ids, input_masks, input_type_ids)

        # Class ids by default, per-class probabilities when return_probas is set.
        if not self.return_probas:
            pred = self.sess.run(self.y_predictions, feed_dict=feed_dict)
        else:
            pred = self.sess.run(self.y_probas, feed_dict=feed_dict)
        return pred