# Source code for deeppavlov.models.torch_bert.torch_transformers_sequence_tagger
# Copyright 2019 Neural Networks and Deep Learning lab, MIPT
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import getLogger
from pathlib import Path
from typing import List, Union, Dict, Optional, Tuple
import numpy as np
import torch
from overrides import overrides
from transformers import AutoModelForTokenClassification, AutoConfig
from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.errors import ConfigError
from deeppavlov.core.common.registry import register
from deeppavlov.core.models.torch_model import TorchModel
from deeppavlov.models.torch_bert.crf import CRF
log = getLogger(__name__)
def token_from_subtoken(units: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Assemble token level units from subtoken level units.

    For every sample, the feature vectors found at positions where ``mask`` is
    non-zero are packed to the left (original order preserved) and the
    remainder of the row is right-padded with zeros.

    Args:
        units: torch.Tensor of shape [batch_size, SUBTOKEN_seq_length, n_features]
        mask: mask of token beginnings of shape [batch_size, SUBTOKEN_seq_length].
            For example: for tokens

                [[``[CLS]``, ``My``, ``capybara``, ``[SEP]``],
                [``[CLS]``, ``Your``, ``aar``, ``##dvark``, ``is``, ``awesome``, ``[SEP]``]]

            the mask will be

                [[0, 1, 1, 0, 0, 0, 0],
                [0, 1, 1, 0, 1, 1, 0]]

    Returns:
        word_level_units: Units assembled from ones in the mask. For the
            example above this units will correspond to the following

                [[``My``, ``capybara``],
                [``Your``, ``aar``, ``is``, ``awesome``]]

            The shape of this tensor will be [batch_size, TOKEN_seq_length, n_features],
            where TOKEN_seq_length is the largest number of marked positions in
            any sample. The result is always ``torch.float64``; shorter samples
            are padded with zeros.
    """
    batch_size, _, n_features = units.size()

    # Any non-zero mask entry marks the first subtoken of a word.
    word_starts = (mask != 0).to(units.device)
    # Number of words (tokens) in every sample of the batch.
    token_seq_lengths = word_starts.sum(dim=1)
    # ``max`` of an empty tensor raises, so an empty batch is handled explicitly.
    max_token_seq_len = int(token_seq_lengths.max()) if batch_size > 0 else 0

    # The zero-initialised output doubles as the right-padding.
    tensor = torch.zeros((batch_size, max_token_seq_len, n_features),
                         dtype=torch.float64, device=units.device)
    if max_token_seq_len == 0:
        # No word beginnings anywhere: nothing to copy.
        return tensor

    # slots[i, j] is True for the first token_seq_lengths[i] positions of row i.
    positions = torch.arange(max_token_seq_len, device=units.device)
    slots = positions.unsqueeze(0) < token_seq_lengths.unsqueeze(1)

    # Boolean indexing enumerates True entries in row-major order on both sides,
    # so the k-th selected subtoken vector lands in the k-th free slot.
    tensor[slots] = units[word_starts].to(torch.float64)
    return tensor
def token_labels_to_subtoken_labels(labels, y_mask, input_mask):
    """Spread word-level labels over the subtokens of a single sample.

    Args:
        labels: one label per word of the sample
        y_mask: per-subtoken mask, equal to 1 at the first subtoken of each word
        input_mask: per-subtoken attention mask; its sum is taken as the number
            of real subtokens including the two special ones

    Returns:
        list with one label per position of ``input_mask``: 0 for the leading
        special position, the label of the current word for every subtoken
        between the special positions, and 0 for everything after them.
    """
    attended = int(np.sum(input_mask))
    expanded = [0]  # label for the leading special subtoken
    word_cursor = 0
    # Walk the subtokens that sit between the two special positions.
    for starts_word in y_mask[1:attended - 1]:
        if starts_word == 1:
            word_cursor += 1
        # Continuation subtokens repeat the label of the word they belong to.
        expanded.append(labels[word_cursor - 1])
    # Trailing special subtoken plus padding up to the length of input_mask.
    tail_length = len(input_mask) - attended + 1
    expanded.extend([0] * tail_length)
    return expanded
@register('torch_transformers_sequence_tagger')
class TorchTransformersSequenceTagger(TorchModel):
    """Transformer-based model on PyTorch for text tagging. It predicts a label for every token (not subtoken)
    in the text. You can use it for sequence labeling tasks, such as morphological tagging or named entity recognition.

    Args:
        n_tags: number of distinct tags
        pretrained_bert: pretrained Bert checkpoint path or key title (e.g. "bert-base-uncased")
        bert_config_file: path to Bert configuration file, or None, if `pretrained_bert` is a string name
        attention_probs_keep_prob: keep_prob for Bert self-attention layers
            (applied only when the model is built from `bert_config_file`)
        hidden_keep_prob: keep_prob for Bert hidden layers
            (applied only when the model is built from `bert_config_file`)
        optimizer: optimizer name from `torch.optim`
        optimizer_parameters: dictionary with optimizer's parameters,
            e.g. {'lr': 0.1, 'weight_decay': 0.001, 'momentum': 0.9};
            defaults to {"lr": 1e-3, "weight_decay": 1e-6} when omitted
        learning_rate_drop_patience: how many validations with no improvements to wait
        learning_rate_drop_div: the divider of the learning rate after `learning_rate_drop_patience` unsuccessful
            validations
        load_before_drop: whether to load best model before dropping learning rate or not
        clip_norm: clip gradients by norm
        min_learning_rate: min value of learning rate if learning rate decay is used
        use_crf: whether to use Conditional Random Field to decode tags
    """

    def __init__(self,
                 n_tags: int,
                 pretrained_bert: str,
                 bert_config_file: Optional[str] = None,
                 attention_probs_keep_prob: Optional[float] = None,
                 hidden_keep_prob: Optional[float] = None,
                 optimizer: str = "AdamW",
                 optimizer_parameters: Optional[dict] = None,
                 learning_rate_drop_patience: int = 20,
                 learning_rate_drop_div: float = 2.0,
                 load_before_drop: bool = True,
                 clip_norm: Optional[float] = None,
                 min_learning_rate: float = 1e-07,
                 use_crf: bool = False,
                 **kwargs) -> None:
        # A fresh dict per instance: a dict literal in the signature would be
        # shared (and mutable) across every instance of the class.
        if optimizer_parameters is None:
            optimizer_parameters = {"lr": 1e-3, "weight_decay": 1e-6}

        # NOTE(review): these attributes are assigned before super().__init__();
        # presumably the base constructor triggers load(), which reads them -
        # confirm against TorchModel before reordering.
        self.n_classes = n_tags
        self.attention_probs_keep_prob = attention_probs_keep_prob
        self.hidden_keep_prob = hidden_keep_prob
        self.clip_norm = clip_norm
        self.pretrained_bert = pretrained_bert
        self.bert_config_file = bert_config_file
        self.use_crf = use_crf

        super().__init__(optimizer=optimizer,
                         optimizer_parameters=optimizer_parameters,
                         learning_rate_drop_patience=learning_rate_drop_patience,
                         learning_rate_drop_div=learning_rate_drop_div,
                         load_before_drop=load_before_drop,
                         min_learning_rate=min_learning_rate,
                         **kwargs)

    def train_on_batch(self,
                       input_ids: Union[List[List[int]], np.ndarray],
                       input_masks: Union[List[List[int]], np.ndarray],
                       y_masks: Union[List[List[int]], np.ndarray],
                       y: List[List[int]],
                       *args, **kwargs) -> Dict[str, float]:
        """Run one optimization step on a batch.

        Args:
            input_ids: batch of indices of subwords
            input_masks: batch of masks which determine what should be attended
            y_masks: batch of masks which mark the first subword of every word
            y: batch of tag indices, one per word (not per subword)
            args: additional positional arguments; not used by this method
            kwargs: additional keyword arguments; not used by this method

        Returns:
            dict with the single field 'loss'
        """
        # The annotations allow plain nested lists, torch.from_numpy does not.
        input_ids = np.asarray(input_ids)
        input_masks = np.asarray(input_masks)

        b_input_ids = torch.from_numpy(input_ids).to(self.device)
        b_input_masks = torch.from_numpy(input_masks).to(self.device)
        # The token classification head is trained per subtoken, so word-level
        # tags are expanded to subtoken level first.
        subtoken_labels = [token_labels_to_subtoken_labels(y_el, y_mask, input_mask)
                           for y_el, y_mask, input_mask in zip(y, y_masks, input_masks)]
        b_labels = torch.from_numpy(np.array(subtoken_labels)).to(torch.int64).to(self.device)

        self.optimizer.zero_grad()
        loss = self.model(input_ids=b_input_ids,
                          attention_mask=b_input_masks,
                          labels=b_labels).loss
        loss.backward()
        if self.use_crf:
            # NOTE(review): the return value is discarded and this runs after
            # backward(); presumably CRF.forward only updates internal state -
            # confirm against deeppavlov.models.torch_bert.crf.CRF.
            self.crf(y, y_masks)

        # Clip the norm of the gradients to `clip_norm`.
        # This is to help prevent the "exploding gradients" problem.
        if self.clip_norm:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_norm)

        self.optimizer.step()
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {'loss': loss.item()}

    def __call__(self,
                 input_ids: Union[List[List[int]], np.ndarray],
                 input_masks: Union[List[List[int]], np.ndarray],
                 y_masks: Union[List[List[int]], np.ndarray]
                 ) -> Tuple[List[Union[List[int], np.ndarray]], np.ndarray]:
        """ Predicts tag indices for a given subword tokens batch

        Args:
            input_ids: indices of the subwords
            input_masks: mask that determines where to attend and where not to
            y_masks: mask which determines the first subword units in the word

        Returns:
            Label indices for each token (not subtoken), cut to the number of
            words of every sample, and class probabilities for each token
            (padded to the longest sample of the batch)
        """
        # The annotations allow plain nested lists, torch.from_numpy does not.
        input_ids = np.asarray(input_ids)
        input_masks = np.asarray(input_masks)
        y_masks = np.asarray(y_masks)

        b_input_ids = torch.from_numpy(input_ids).to(self.device)
        b_input_masks = torch.from_numpy(input_masks).to(self.device)

        with torch.no_grad():
            # Forward pass, calculate logit predictions
            logits = self.model(b_input_ids, attention_mask=b_input_masks)

        # Move logits to CPU and keep only the first subtoken of every word
        logits = token_from_subtoken(logits[0].detach().cpu(), torch.from_numpy(y_masks))

        probas = torch.nn.functional.softmax(logits, dim=-1)
        probas = probas.detach().cpu().numpy()

        if self.use_crf:
            # NOTE(review): the transpose presumably matches a
            # (seq_len, batch, n_tags) layout expected by CRF.decode - confirm.
            logits = logits.transpose(1, 0).to(self.device)
            pred = self.crf.decode(logits)
        else:
            logits = logits.detach().cpu().numpy()
            pred = np.argmax(logits, axis=-1)

        # Drop predictions made for padding positions.
        seq_lengths = np.sum(y_masks, axis=1)
        pred = [p[:l] for l, p in zip(seq_lengths, pred)]

        return pred, probas

    @overrides
    def load(self, fname=None):
        """Build the network, the optional CRF layer, the optimizer and the scheduler, then restore weights.

        Args:
            fname: optional path that replaces ``self.load_path``

        Raises:
            ConfigError: if neither `pretrained_bert` nor an existing `bert_config_file` is given
        """
        if fname is not None:
            self.load_path = fname

        if self.pretrained_bert:
            config = AutoConfig.from_pretrained(self.pretrained_bert, num_labels=self.n_classes,
                                                output_attentions=False, output_hidden_states=False)
            self.model = AutoModelForTokenClassification.from_pretrained(self.pretrained_bert, config=config)
        elif self.bert_config_file and Path(self.bert_config_file).is_file():
            # AutoConfig.from_pretrained accepts a path to a configuration JSON
            # file. num_labels is set as in the branch above so that the
            # classification head matches n_tags.
            self.bert_config = AutoConfig.from_pretrained(str(expand_path(self.bert_config_file)),
                                                          num_labels=self.n_classes,
                                                          output_attentions=False, output_hidden_states=False)
            if self.attention_probs_keep_prob is not None:
                self.bert_config.attention_probs_dropout_prob = 1.0 - self.attention_probs_keep_prob
            if self.hidden_keep_prob is not None:
                self.bert_config.hidden_dropout_prob = 1.0 - self.hidden_keep_prob
            # Auto classes cannot be instantiated directly; a randomly
            # initialised model is created with from_config.
            self.model = AutoModelForTokenClassification.from_config(self.bert_config)
        else:
            raise ConfigError("No pre-trained BERT model is given.")

        self.model.to(self.device)
        if self.use_crf:
            self.crf = CRF(self.n_classes).to(self.device)

        self.optimizer = getattr(torch.optim, self.optimizer_name)(
            self.model.parameters(), **self.optimizer_parameters)
        if self.lr_scheduler_name is not None:
            self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
                self.optimizer, **self.lr_scheduler_parameters)

        if self.load_path:
            super().load()
            if self.use_crf:
                # CRF weights live next to the main checkpoint: "<load_path>_crf.pth.tar"
                weights_path_crf = Path(f"{self.load_path}_crf").resolve()
                weights_path_crf = weights_path_crf.with_suffix(".pth.tar")
                if weights_path_crf.exists():
                    checkpoint = torch.load(weights_path_crf, map_location=self.device)
                    self.crf.load_state_dict(checkpoint["model_state_dict"], strict=False)
                else:
                    log.warning(f"Init from scratch. Load path {weights_path_crf} does not exist.")

    @overrides
    def save(self, fname: Optional[str] = None, *args, **kwargs) -> None:
        """Save the model and, if enabled, the CRF weights to "<fname>_crf.pth.tar".

        Args:
            fname: base path for the CRF checkpoint; ``self.save_path`` is used when omitted
        """
        # NOTE(review): the base-class save is called without `fname`, exactly as
        # before, so the main checkpoint location is decided by TorchModel.
        super().save()
        if self.use_crf:
            if fname is None:
                # Without this fallback the weights ended up in a file literally
                # named "None_crf.pth.tar" in the working directory.
                # NOTE(review): save_path is assumed to be provided by the base class.
                fname = self.save_path
            weights_path_crf = Path(f"{fname}_crf").resolve()
            weights_path_crf = weights_path_crf.with_suffix(".pth.tar")
            # State is serialised from CPU, then the layer is moved back.
            torch.save({"model_state_dict": self.crf.cpu().state_dict()}, weights_path_crf)
            self.crf.to(self.device)