|
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """
- This file contains the pattern-verbalizer pairs (PVPs) for all tasks.
- """
- import copy
- import math
- import random
- import string
- from abc import ABC, abstractmethod
- from collections import defaultdict
- from typing import Tuple, List, Union, Dict
- import numpy as np
-
- from tasks.data_utils import InputExample, num_special_tokens_to_add, build_input_from_ids, build_sample, \
- build_decoder_input, build_decoder_sample
- from utils import print_rank_0
-
# A single piece of a filled pattern: plain text, or a (text, shortenable) pair as built by PVP.shortenable.
_PatternPart = Union[str, Tuple[str, bool]]
# The pair of part sequences (text_a parts, text_b parts) returned by PVP.get_parts.
# NOTE(review): the get_parts implementations in this module also place None placeholders and [mask_id]
# lists in these sequences, which this alias does not describe.
FilledPattern = Tuple[List[_PatternPart], List[_PatternPart]]
-
-
class PVP(ABC):
    """
    This class contains functions to apply patterns and verbalizers as required by PET. Each task requires its own
    custom implementation of a PVP.

    Subclasses provide the pattern through :meth:`get_parts` and the verbalizer through :meth:`verbalize`;
    :meth:`encode` combines both to turn an ``InputExample`` into model input.
    """

    def __init__(self, args, tokenizer, label_list, max_seq_length, pattern_id: int = 0, verbalizer_file: str = None,
                 seed: int = 42, is_multi_token=False, max_segment_length=0, fast_decode: bool = False, split='train',
                 num_prompt_tokens=0):
        """
        Create a new PVP.

        :param args: the args; ``task_mask``, ``continuous_prompt`` and ``prefix_prompt`` are read from it
        :param tokenizer: the tokenizer
        :param label_list: the list of labels
        :param max_seq_length: the maximum length of the sequence
        :param pattern_id: the pattern id to use
        :param verbalizer_file: optional path to a verbalizer file; when given, the verbalizer loaded from it (for
            ``pattern_id``) replaces :meth:`verbalize` on this instance
        :param seed: a seed to be used for generating random numbers if necessary
        :param is_multi_token: if the verbalizers contain multiple tokens
        :param max_segment_length: if positive, multi-token answers are split into segments of at most this many
            tokens in :meth:`encode`
        :param fast_decode: whether to use the fast decode mode for multi-token tasks
        :param split: name of the data split being encoded (e.g. ``'train'``)
        :param num_prompt_tokens: number of continuous prompt tokens spread over the ``None`` placeholders of a
            pattern (see :meth:`replace_prompt_tokens`)
        """
        self.args = args
        self.tokenizer = tokenizer
        self.label_list = label_list
        self.max_seq_length = max_seq_length
        self.pattern_id = pattern_id
        self.num_prompt_tokens = num_prompt_tokens
        self.rng = random.Random(seed)
        # Counts truncate() calls that actually removed tokens.
        self.num_truncated = 0
        self.fast_decode = fast_decode
        self.split = split
        # Answers are cut to this many tokens on the fast-decode path of encode().
        self.max_dec_seq_length = 16
        self._is_multi_token = is_multi_token
        self.max_segment_length = max_segment_length
        self.task_mask = args.task_mask
        self.continuous_prompt = args.continuous_prompt
        self.prefix_prompt = args.prefix_prompt
        if self.continuous_prompt:
            print_rank_0(f"Prompt tokens in pvp {self.num_prompt_tokens} spell length {self.spell_length}")

        if verbalizer_file:
            # Shadow the verbalize() method with the verbalizer read from the file.
            self.verbalize = PVP._load_verbalizer_from_file(verbalizer_file, self.pattern_id)

    @property
    def is_multi_token(self):
        """Whether the answers (verbalizations) may span multiple tokens."""
        return self._is_multi_token

    @property
    def spell_length(self):
        """Number of continuous prompt positions used by this PVP; 0 unless a subclass overrides it."""
        return 0

    @property
    def mask(self) -> int:
        """Return the underlying LM's mask token id (the same value as :attr:`mask_id`)."""
        return self.tokenizer.get_command('MASK').Id

    @property
    def mask_id(self) -> int:
        """Return the underlying LM's mask id"""
        return self.tokenizer.get_command('MASK').Id

    @property
    def max_num_verbalizers(self) -> int:
        """Return the maximum number of verbalizers across all labels"""
        return max(len(self.verbalize(label)) for label in self.label_list)

    @staticmethod
    def shortenable(s):
        """Return an instance of this string that is marked as shortenable"""
        return s, True

    @staticmethod
    def remove_final_punc(s: Union[str, Tuple[str, bool]]):
        """Remove the final punctuation mark(s); a (text, shortenable) pair keeps its flag."""
        if isinstance(s, tuple):
            return PVP.remove_final_punc(s[0]), s[1]
        return s.rstrip(string.punctuation)

    @staticmethod
    def lowercase_first(s: Union[str, Tuple[str, bool]]):
        """Lowercase the first character (an empty string is returned unchanged)."""
        if isinstance(s, tuple):
            return PVP.lowercase_first(s[0]), s[1]
        # Slicing instead of indexing s[0] keeps an empty string from raising IndexError.
        return s[:1].lower() + s[1:]

    @staticmethod
    def uppercase_first(s: Union[str, Tuple[str, bool]]):
        """Uppercase the first character (an empty string is returned unchanged)."""
        if isinstance(s, tuple):
            return PVP.uppercase_first(s[0]), s[1]
        # Slicing instead of indexing s[0] keeps an empty string from raising IndexError.
        return s[:1].upper() + s[1:]

    @staticmethod
    def available_patterns():
        """Return the pattern ids supported by this PVP."""
        return [0]

    def replace_prompt_tokens(self, parts_a, parts_b):
        """
        Resolve the ``None`` placeholders that patterns use to mark continuous prompt positions.

        Without continuous prompts the placeholders are simply removed. Otherwise ``self.num_prompt_tokens`` is
        spread over the placeholders in order (``parts_a`` first, then ``parts_b``). Each placeholder is replaced by
        an int of at most ``ceil(num_prompt_tokens / num_placeholders)``. Placeholders still left once the budget is
        used up are dropped. :meth:`encode` later expands every int into that many prompt token ids.

        :param parts_a: the first sequence of pattern parts
        :param parts_b: the second sequence of pattern parts
        :return: the two sequences as new lists with every placeholder resolved
        """
        if not self.continuous_prompt:
            parts_a = [part for part in parts_a if part is not None]
            parts_b = [part for part in parts_b if part is not None]
            return parts_a, parts_b

        num_pos = sum(part is None for parts in (parts_a, parts_b) for part in parts)
        if num_pos == 0:
            # Nothing to fill; this also avoids dividing by zero below.
            return list(parts_a), list(parts_b)
        remaining = self.num_prompt_tokens
        avg_prompt_tokens = math.ceil(remaining / num_pos)

        def fill(parts):
            # Replace placeholders in ``parts`` with prompt-token counts drawn from the shared budget.
            nonlocal remaining
            new_parts = []
            for part in parts:
                if part is not None:
                    new_parts.append(part)
                elif remaining > 0:
                    num_tokens = min(avg_prompt_tokens, remaining)
                    new_parts.append(num_tokens)
                    remaining -= num_tokens
            return new_parts

        new_parts_a = fill(parts_a)
        new_parts_b = fill(parts_b)
        return new_parts_a, new_parts_b

    def encode(self, example: InputExample, priming: bool = False, labeled: bool = False):
        """
        Encode an input example using this pattern-verbalizer pair.

        :param example: the input example to encode
        :param priming: whether to use this example for priming
        :param labeled: if ``priming=True``, whether the label should be appended to this example
        :return: with ``priming=True``, a flat list of input ids (with the mask replaced by the label's
            verbalization when ``labeled=True``); otherwise the sample built by ``build_sample`` (or by
            ``build_decoder_sample`` in fast-decode mode)
        """

        if not priming:
            assert not labeled, "'labeled' can only be set to true if 'priming' is also set to true"

        tokenizer = self.tokenizer
        raw_parts_a, raw_parts_b = self.get_parts(example)

        raw_parts_a = [x if isinstance(x, tuple) else (x, False) for x in raw_parts_a]
        # Placeholder id for continuous prompt positions (presumably one past the last real token id - confirm).
        # Its positions are collected as prompt positions and it is replaced by 0 in the returned ids.
        prompt_id = tokenizer.num_tokens

        def encode_input(raw_parts):
            # Tokenize text parts and expand int parts (prompt-token counts) into runs of prompt_id.
            parts = []
            for x, s in raw_parts:
                if isinstance(x, str):
                    x = tokenizer.EncodeAsIds(x)
                elif isinstance(x, int):
                    x = [prompt_id] * x
                else:
                    pass
                parts.append((x, s))
            return parts

        parts_a = encode_input(raw_parts_a)
        if self.prefix_prompt > 0:
            # Prepend a non-shortenable block of prefix prompt tokens.
            parts_a = [([prompt_id] * self.prefix_prompt, False)] + parts_a

        parts_b = None
        if raw_parts_b:
            raw_parts_b = [x if isinstance(x, tuple) else (x, False) for x in raw_parts_b]
            parts_b = encode_input(raw_parts_b)

        if self.is_multi_token:
            answers = self.get_answers(example)
            if example.label is not None:
                label = self.label_list.index(example.label)
            else:
                label = 0

            if not self.fast_decode:
                ids_list, positions_list, sep_list, mask_list, target_list, prompt_list = [], [], [], [], [], []
                segment_id_list = []
                if priming:
                    answer = answers[label]
                    answer_ids = get_verbalization_ids(answer, tokenizer, force_single_token=False)
                    self.num_truncated += self.truncate(parts_a, parts_b, answer_ids, max_length=self.max_seq_length)
                    tokens_a = [token_id for part, _ in parts_a for token_id in part]
                    tokens_b = [token_id for part, _ in parts_b for token_id in part] if parts_b else None
                    input_ids = tokens_a
                    if tokens_b:
                        input_ids += tokens_b
                    if labeled:
                        # Splice the whole answer in place of the single mask token.
                        mask_idx = input_ids.index(self.mask_id)
                        input_ids = input_ids[:mask_idx] + answer_ids + input_ids[mask_idx + 1:]
                    return input_ids
                else:
                    # Build one model input per answer candidate (or per answer segment).
                    for idx, answer in enumerate(answers):
                        this_parts_a, this_parts_b = copy.deepcopy(parts_a), copy.deepcopy(parts_b)
                        answer_ids = get_verbalization_ids(answer, tokenizer, force_single_token=False)
                        answer_ids = answer_ids + [tokenizer.get_command('eop').Id]
                        self.num_truncated += self.truncate(this_parts_a, this_parts_b, answer_ids,
                                                            max_length=self.max_seq_length)
                        tokens_a = [token_id for part, _ in this_parts_a for token_id in part]
                        tokens_b = [token_id for part, _ in this_parts_b for token_id in part] if parts_b else None
                        if self.max_segment_length > 0:
                            # Split the answer into chunks of at most max_segment_length tokens; segment_id_list
                            # records which answer each chunk came from.
                            num_segments = (len(answer_ids) - 1) // self.max_segment_length + 1
                            segments = [
                                answer_ids[index * self.max_segment_length: (index + 1) * self.max_segment_length]
                                for index in range(num_segments)]
                            segment_id_list += [idx] * len(segments)
                        else:
                            segments = [answer_ids]
                        for segment in segments:
                            data = build_input_from_ids(tokens_a, tokens_b, segment, self.max_seq_length,
                                                        self.tokenizer,
                                                        args=self.args, add_cls=True, add_sep=False, add_piece=True,
                                                        mask_id=self.mask_id)
                            ids, types, paddings, position_ids, sep, target_ids, loss_masks = data
                            prompt_pos = [pos for pos, token in enumerate(ids) if token == prompt_id]
                            ids = [token if token != prompt_id else 0 for token in ids]
                            prompt_list.append(prompt_pos)
                            ids_list.append(ids)
                            positions_list.append(position_ids)
                            sep_list.append(sep)
                            target_list.append(target_ids)
                            mask_list.append(loss_masks)
                            # Insert this segment right before the mask so the next segment's input contains it.
                            if self.mask in tokens_a:
                                mask_pos = tokens_a.index(self.mask)
                                tokens_a = tokens_a[:mask_pos] + segment + tokens_a[mask_pos:]
                            else:
                                mask_pos = tokens_b.index(self.mask)
                                tokens_b = tokens_b[:mask_pos] + segment + tokens_b[mask_pos:]
                    segment_id_list = segment_id_list if segment_id_list else None
                    sample = build_sample(ids_list, positions=positions_list, masks=sep_list, label=label,
                                          logit_mask=mask_list, target=target_list,
                                          unique_id=example.guid, segment_ids=segment_id_list, prompt_ids=prompt_list)
                    return sample
            else:
                # Fast decode: encode the context once, then build one decoder input per answer.
                this_parts_a, this_parts_b = copy.deepcopy(parts_a), copy.deepcopy(parts_b)
                self.num_truncated += self.truncate(this_parts_a, this_parts_b, None, max_length=self.max_seq_length)
                tokens_a = [token_id for part, _ in this_parts_a for token_id in part]
                tokens_b = [token_id for part, _ in this_parts_b for token_id in part] if parts_b else None
                data = build_input_from_ids(tokens_a, tokens_b, None, self.max_seq_length, self.tokenizer,
                                            args=self.args, add_cls=True, add_sep=False, add_piece=False)
                ids, types, paddings, position_ids, sep, target_ids, loss_masks = data
                sample = build_sample(ids, positions=position_ids, masks=sep, label=label, unique_id=example.guid)

                ids_list, positions_list, mask_list, target_list, logit_mask_list = [], [], [], [], []
                for answer in answers:
                    answer_ids = get_verbalization_ids(answer, tokenizer, force_single_token=False)
                    answer_ids = answer_ids + [tokenizer.get_command('eop').Id]
                    answer_ids = answer_ids[:self.max_dec_seq_length]
                    data = build_decoder_input(ids, answer_ids, self.max_seq_length, self.max_dec_seq_length,
                                               tokenizer)
                    dec_ids, _, _, dec_position_ids, _, dec_target_ids, dec_loss_masks = data
                    ids_list.append(dec_ids)
                    positions_list.append(dec_position_ids)
                    mask_list.append(sep)
                    target_list.append(dec_target_ids)
                    logit_mask_list.append(dec_loss_masks)

                sample = build_decoder_sample(sample, ids_list, positions_list, mask_list, target_list,
                                              logit_mask_list)
                return sample

        else:
            # Single-token verbalizers: one input, scored against every label's verbalizer id at the mask.
            self.num_truncated += self.truncate(parts_a, parts_b, [], max_length=self.max_seq_length)

            tokens_a = [token_id for part, _ in parts_a for token_id in part]
            tokens_b = [token_id for part, _ in parts_b for token_id in part] if parts_b else None
            if priming:
                input_ids = tokens_a
                if tokens_b:
                    input_ids += tokens_b
                if labeled:
                    mask_idx = input_ids.index(self.mask_id)
                    verbalizer = self.verbalize(example.label)
                    assert len(verbalizer) == 1, 'priming only supports one verbalization per label'
                    verbalizer = verbalizer[0]
                    verbalizer_id = get_verbalization_ids(verbalizer, self.tokenizer, force_single_token=True)
                    input_ids[mask_idx] = verbalizer_id
                return input_ids
            data = build_input_from_ids(tokens_a, tokens_b, None, self.max_seq_length, self.tokenizer, args=self.args,
                                        add_cls=True, add_sep=False, add_piece=True)
            ids, types, paddings, position_ids, sep, target_ids, loss_masks = data
            prompt_pos = [pos for pos, token in enumerate(ids) if token == prompt_id]
            ids = [token if token != prompt_id else 0 for token in ids]
            target_ids = self.get_verbalizer_ids()
            if example.label is not None:
                label = self.label_list.index(example.label)
            else:
                label = 0
            sample = build_sample(ids=ids, positions=position_ids, target=target_ids, masks=sep,
                                  logit_mask=loss_masks, label=label, unique_id=example.guid, prompt_ids=prompt_pos)
            return sample

    @staticmethod
    def _seq_length(parts: List[Tuple[List[int], bool]], only_shortenable: bool = False):
        """Total number of tokens in ``parts`` (only shortenable parts if requested); 0 for ``None``/empty."""
        return sum([len(x) for x, shortenable in parts if not only_shortenable or shortenable]) if parts else 0

    @staticmethod
    def _remove_last(parts: List[Tuple[List[int], bool]]):
        """Drop the final token of the last non-empty shortenable part (in place)."""
        last_idx = max(idx for idx, (seq, shortenable) in enumerate(parts) if shortenable and seq)
        parts[last_idx] = (parts[last_idx][0][:-1], parts[last_idx][1])

    def truncate(self, parts_a: List[Tuple[List[int], bool]], parts_b: List[Tuple[List[int], bool]], answer: List[int],
                 max_length: int):
        """
        Truncate two sequences of text to a predefined total maximum length.

        Tokens are removed one at a time from the end of the last non-empty shortenable part, taken from whichever
        sequence currently has more shortenable tokens (``parts_b`` on ties). The parts are modified in place.

        :param parts_a: encoded parts of the first sequence
        :param parts_b: encoded parts of the second sequence (may be ``None``)
        :param answer: answer token ids that also count towards the length (may be ``None`` or empty)
        :param max_length: the maximum total length, special tokens included
        :return: ``True`` if any token was removed, ``False`` otherwise
        """
        total_len = self._seq_length(parts_a) + self._seq_length(parts_b)
        if answer:
            total_len += len(answer)
        total_len += num_special_tokens_to_add(parts_a, parts_b, answer, add_cls=True, add_sep=False, add_piece=True)
        num_tokens_to_remove = total_len - max_length

        if num_tokens_to_remove <= 0:
            return False

        for _ in range(num_tokens_to_remove):
            if self._seq_length(parts_a, only_shortenable=True) > self._seq_length(parts_b, only_shortenable=True):
                self._remove_last(parts_a)
            else:
                self._remove_last(parts_b)
        return True

    @abstractmethod
    def get_parts(self, example: InputExample) -> FilledPattern:
        """
        Given an input example, apply a pattern to obtain two text sequences (text_a and text_b) containing exactly one
        mask token (or one consecutive sequence of mask tokens for PET with multiple masks). If a task requires only a
        single sequence of text, the second sequence should be an empty list.

        :param example: the input example to process
        :return: Two sequences of text. All text segments can optionally be marked as being shortenable.
        """
        pass

    def get_answers(self, example: InputExample):
        """Return the first verbalization of every label, in ``label_list`` order."""
        return [self.verbalize(label)[0] for label in self.label_list]

    def get_verbalizer_ids(self):
        """Return the single-token id of the first verbalization of every label, in ``label_list`` order."""
        target_ids = []
        for label in self.label_list:
            verbalizer = self.verbalize(label)[0]
            verbalizer_id = get_verbalization_ids(verbalizer, self.tokenizer, force_single_token=True)
            target_ids.append(verbalizer_id)
        return target_ids

    @abstractmethod
    def verbalize(self, label) -> List[str]:
        """
        Return all verbalizations for a given label.

        :param label: the label
        :return: the list of verbalizations
        """
        pass

    def get_mask_positions(self, input_ids: List[int]) -> List[int]:
        """Return a list with 1 at the first mask position of ``input_ids`` and -1 everywhere else."""
        label_idx = input_ids.index(self.mask_id)
        labels = [-1] * len(input_ids)
        labels[label_idx] = 1
        return labels

    @staticmethod
    def _load_verbalizer_from_file(path: str, pattern_id: int):
        """
        Load a verbalizer from a text file.

        A line consisting only of digits starts the section for that pattern id. Every other non-empty line has the
        form ``<label> <verbalization> [<verbalization> ...]`` (whitespace separated).

        :param path: path of the verbalizer file (read as UTF-8)
        :param pattern_id: the pattern whose section is used
        :return: a function mapping a label to its list of verbalizations
        """
        verbalizers = defaultdict(dict)  # type: Dict[int, Dict[str, List[str]]]
        current_pattern_id = None

        # Explicit encoding so the result does not depend on the platform's locale default.
        with open(path, 'r', encoding='utf-8') as fh:
            for line in fh.read().splitlines():
                if line.isdigit():
                    current_pattern_id = int(line)
                elif line:
                    label, *realizations = line.split()
                    verbalizers[current_pattern_id][label] = realizations

        print_rank_0("Automatically loaded the following verbalizer: \n {}".format(verbalizers[pattern_id]))

        def verbalize(label) -> List[str]:
            return verbalizers[pattern_id][label]

        return verbalize
-
-
class CopaPVP(PVP):
    """
    PVP for COPA: given a premise, pick which of two alternatives is its cause or its effect.

    Both alternatives are scored as multi-token answers. Patterns 0 and 1 (and any pattern when continuous prompts
    are enabled) are encoded by :meth:`PVP.encode`; pattern 2 uses the custom encoding in :meth:`encode`.
    """

    @staticmethod
    def available_patterns():
        # NOTE(review): pattern 2 is also implemented by get_parts/encode but is not listed here.
        return [0, 1]

    @property
    def is_multi_token(self):
        return True

    @property
    def spell_length(self):
        return self.num_prompt_tokens + self.prefix_prompt

    def _normalize_choice(self, choice):
        """Lowercase the first character of an alternative and strip its trailing punctuation."""
        return self.remove_final_punc(self.lowercase_first(choice))

    def get_answers(self, example: InputExample):
        """Return both alternatives (normalized), each prefixed with a space."""
        choice1 = " " + self._normalize_choice(example.meta['choice1'])
        choice2 = " " + self._normalize_choice(example.meta['choice2'])
        return [choice1, choice2]

    def get_parts(self, example: InputExample) -> FilledPattern:
        """
        Build the pattern '"<choice1>" or "<choice2>"? <premise> because/, so <mask>.' (wording varies by pattern).

        ``None`` entries mark continuous prompt positions and are resolved by :meth:`replace_prompt_tokens`.

        :raises NotImplementedError: for pattern ids other than 0, 1 and 2
        """
        premise = self.remove_final_punc(self.shortenable(" " + example.text_a))
        choice1 = self._normalize_choice(example.meta['choice1'])
        choice2 = self._normalize_choice(example.meta['choice2'])

        question = example.meta['question']
        assert question in ['cause', 'effect']
        joiner = ' because' if question == 'cause' else ', so'
        if self.pattern_id == 0:
            parts_a = [None, '"', choice1, '" or "', choice2, '"?', None, premise, joiner, None, [self.mask], '.']
        elif self.pattern_id == 1:
            parts_a = [None, choice1, ' or', " " + choice2, '?', None, premise, joiner, None, [self.mask], '.']
        elif self.pattern_id == 2:
            parts_a = [None, '"', choice1, '" or "', choice2, '"', None, premise, joiner, [self.mask], '.', None]
        else:
            raise NotImplementedError(self.pattern_id)
        parts_a, parts_b = self.replace_prompt_tokens(parts_a, [])
        return parts_a, parts_b

    def verbalize(self, label) -> List[str]:
        # Answers come from get_answers(); there is no per-label verbalizer.
        return []

    def encode(self, example: InputExample, priming: bool = False, labeled: bool = False):
        """
        Encode an input example using this pattern-verbalizer pair.

        For pattern 2 without continuous prompts, each alternative is placed after the mask and the model is asked to
        produce the connective (" because" / " so"); every other case defers to :meth:`PVP.encode`.

        :param example: the input example to encode
        :param priming: whether to use this example for priming
        :param labeled: if ``priming=True``, whether the label should be appended to this example
        :return: the sample built by ``build_sample`` (one input per alternative), or whatever :meth:`PVP.encode`
            returns when it is used
        """
        if self.continuous_prompt or self.pattern_id < 2:
            return super().encode(example, priming=priming, labeled=labeled)
        if not priming:
            assert not labeled, "'labeled' can only be set to true if 'priming' is also set to true"

        tokenizer = self.tokenizer
        premise = self.remove_final_punc(self.shortenable(example.text_a))
        choice1 = " " + self._normalize_choice(example.meta['choice1'])
        choice2 = " " + self._normalize_choice(example.meta['choice2'])
        question = example.meta['question']
        assert question in ['cause', 'effect']
        answer = " because" if question == 'cause' else " so"
        answer_ids = [get_verbalization_ids(answer, tokenizer, force_single_token=True)]
        if self.is_multi_token:
            answer_ids.append(tokenizer.get_command('eop').Id)

        ids_list, positions_list, sep_list, mask_list, target_list = [], [], [], [], []

        for choice in [choice1, choice2]:
            # choice1[1:] / choice2[1:] drop the leading space added above.
            parts = ['"', choice1[1:], '" or "', choice2[1:], '"?', premise, [self.mask], choice]
            parts = [x if isinstance(x, tuple) else (x, False) for x in parts]
            # Tokenize text parts; empty parts are dropped.
            parts = [(tokenizer.EncodeAsIds(x).tokenization if isinstance(x, str) else x, s) for x, s in parts if x]
            self.num_truncated += self.truncate(parts, None, answer_ids, max_length=self.max_seq_length)
            tokens_a = [token_id for part, _ in parts for token_id in part]
            data = build_input_from_ids(tokens_a, None, answer_ids, self.max_seq_length, self.tokenizer, args=self.args,
                                        add_cls=True, add_sep=False, add_piece=True)
            ids, types, paddings, position_ids, sep, target_ids, loss_masks = data
            ids_list.append(ids)
            positions_list.append(position_ids)
            sep_list.append(sep)
            target_list.append(target_ids)
            mask_list.append(loss_masks)
        if example.label is not None:
            label = self.label_list.index(example.label)
        else:
            label = 0
        sample = build_sample(ids_list, positions=positions_list, masks=sep_list, label=label,
                              logit_mask=mask_list, target=target_list,
                              unique_id=example.guid)
        return sample
-
-
class WscPVP(PVP):
    """
    PVP for WSC: decide which noun phrase a highlighted pronoun refers to.

    The pronoun is wrapped in ``*`` inside the sentence and the referent is scored/generated as a multi-token
    answer.
    """

    @staticmethod
    def available_patterns():
        return [0, 1, 2]

    @property
    def is_multi_token(self):
        return True

    @property
    def spell_length(self):
        return self.num_prompt_tokens + self.prefix_prompt

    def get_answers(self, example: InputExample):
        """Return the target span first, followed by any candidates in ``example.meta['candidates']``."""
        answers = [" " + example.meta['span1_text']]
        if 'candidates' in example.meta:
            answers += [" " + cand for cand in example.meta['candidates']]
        return answers

    def get_parts(self, example: InputExample) -> FilledPattern:
        """
        Build the pattern: the sentence with the pronoun wrapped in ``*``, then a question asking what it refers to.

        ``None`` entries mark continuous prompt positions and are resolved by :meth:`replace_prompt_tokens`.

        :raises NotImplementedError: for pattern ids other than 0, 1 and 2
        """
        pronoun = example.meta['span2_text']
        pronoun_idx = example.meta['span2_index']

        # span2_index is used as a whitespace-token index into text_a.
        words_a = example.text_a.split()
        words_a[pronoun_idx] = '*' + words_a[pronoun_idx] + '*'
        text_a = ' '.join(words_a)
        text_a = self.shortenable(text_a)

        if self.pattern_id == 0:
            parts_a, parts_b = [None, text_a, None, " The pronoun '*" + pronoun + "*' refers to", None, [self.mask],
                                '.'], []
        elif self.pattern_id == 1:
            parts_a, parts_b = [None, text_a, None,
                                " In the previous sentence, the pronoun '*" + pronoun + "*' refers to", None,
                                [self.mask], '.'], []
        elif self.pattern_id == 2:
            parts_a, parts_b = [None, text_a, None,
                                " Question: In the passage above, what does the pronoun '*" + pronoun + "*' refer to?",
                                None,
                                " Answer:", [self.mask], '.'], []
        else:
            raise NotImplementedError(self.pattern_id)
        parts_a, parts_b = self.replace_prompt_tokens(parts_a, parts_b)
        return parts_a, parts_b

    def encode(self, example: InputExample, priming: bool = False, labeled: bool = False):
        """
        Encode an input example using this pattern-verbalizer pair.

        With ``args.loss_func`` in ``('generative', 'mix')`` this defers to :meth:`PVP.encode` (forcing label 0 on
        the train split). Otherwise a single input is built with the target span as the answer to generate.

        :param example: the input example to encode
        :param priming: whether to use this example for priming
        :param labeled: if ``priming=True``, whether the label should be appended to this example
        :return: a dict with keys ``text``, ``target``, ``attention_mask``, ``loss_mask``, ``position_id``,
            ``prompt_pos`` (int64 numpy arrays), ``label`` and ``uid``; or whatever :meth:`PVP.encode` returns
        """
        if self.args.loss_func in ['generative', 'mix']:
            sample = super().encode(example, priming=priming, labeled=labeled)
            if self.split == 'train':
                sample['label'] = 0
            return sample

        if not priming:
            assert not labeled, "'labeled' can only be set to true if 'priming' is also set to true"

        tokenizer = self.tokenizer
        # Placeholder id for continuous prompt positions; replaced by 0 in the returned ids.
        prompt_id = tokenizer.num_tokens
        raw_parts_a, raw_parts_b = self.get_parts(example)

        raw_parts_a = [x if isinstance(x, tuple) else (x, False) for x in raw_parts_a]

        def encode_input(raw_parts):
            # Tokenize text parts and expand int parts (prompt-token counts) into runs of prompt_id.
            parts = []
            for x, s in raw_parts:
                if isinstance(x, str):
                    x = tokenizer.EncodeAsIds(x)
                elif isinstance(x, int):
                    x = [prompt_id] * x
                else:
                    pass
                parts.append((x, s))
            return parts

        parts_a = encode_input(raw_parts_a)
        if self.prefix_prompt > 0:
            parts_a = [([prompt_id] * self.prefix_prompt, False)] + parts_a
        parts_b = None
        if raw_parts_b:
            raw_parts_b = [x if isinstance(x, tuple) else (x, False) for x in raw_parts_b]
            parts_b = encode_input(raw_parts_b)
        # Only the target span (the first answer) is used as the generation target here.
        answer = self.get_answers(example)[0]
        answer_ids = get_verbalization_ids(answer, tokenizer, force_single_token=False)
        answer_ids = answer_ids + [tokenizer.get_command('eop').Id]
        self.num_truncated += self.truncate(parts_a, parts_b, answer_ids, max_length=self.max_seq_length)
        tokens_a = [token_id for part, _ in parts_a for token_id in part]
        tokens_b = [token_id for part, _ in parts_b for token_id in part] if parts_b else None
        data = build_input_from_ids(tokens_a, tokens_b, answer_ids, self.max_seq_length, self.tokenizer, args=self.args,
                                    add_cls=True, add_sep=False, add_piece=True)
        ids, types, paddings, position_ids, sep, target_ids, loss_masks = data
        prompt_pos = [pos for pos, token in enumerate(ids) if token == prompt_id]
        ids = [token if token != prompt_id else 0 for token in ids]
        if example.label is not None:
            label = self.label_list.index(example.label)
        else:
            label = 0
        return {'text': np.array(ids, dtype=np.int64), 'target': np.array(target_ids, dtype=np.int64),
                'attention_mask': np.array(sep, dtype=np.int64), 'loss_mask': np.array(loss_masks, dtype=np.int64),
                "position_id": np.array(position_ids, dtype=np.int64),
                'prompt_pos': np.array(prompt_pos, dtype=np.int64), 'label': label, 'uid': example.guid}

    def verbalize(self, label) -> List[str]:
        # Answers come from get_answers(); there is no per-label verbalizer.
        return []
-
-
class RecordPVP(PVP):
    """PVP for ReCoRD: fill the ``@placeholder`` of a cloze-style query with one of the candidate entities."""

    @property
    def is_multi_token(self):
        # Candidate entities are scored as multi-token answers.
        return True

    def get_answers(self, example: InputExample):
        """Return the candidate entities, each with a leading space."""
        return list(map(lambda candidate: " " + candidate, example.meta['candidates']))

    def get_parts(self, example: InputExample) -> FilledPattern:
        """Return ``[passage, query-before-placeholder, <mask>, query-after-placeholder]`` and an empty text_b."""
        passage = self.shortenable(example.text_a)
        query = example.text_b

        assert '@placeholder' in query, f'question "{query}" does not contain a @placeholder token'
        before, after = query.split('@placeholder')
        text_a = [passage, " " + before.rstrip(), [self.mask], after]
        return text_a, []

    def verbalize(self, label) -> List[str]:
        # Answers come from get_answers(); there is no per-label verbalizer.
        return []
-
-
class RacePVP(PVP):
    """PVP for RACE: answer a multiple-choice question about a passage by scoring each choice."""

    @property
    def is_multi_token(self):
        # Choices are scored as multi-token answers.
        return True

    @staticmethod
    def available_patterns():
        return [0, 1]

    def get_answers(self, example: InputExample):
        """Return the answer choices, each with a leading space."""
        return [" " + option for option in example.meta['choices']]

    def get_parts(self, example: InputExample) -> FilledPattern:
        """
        Return the passage as text_a and the question (containing the mask) as text_b.

        If the question contains ``_``, the mask replaces its first ``_`` (cloze style). Otherwise the mask follows
        the question. Pattern 0 keeps the question as-is; other patterns prefix " Based on the previous passage,".
        """
        passage = self.shortenable(example.text_a)
        question = " " + example.text_b

        if "_" not in question:
            if self.pattern_id == 0:
                return [passage], [" Question:", self.shortenable(question), " Answer:", [self.mask]]
            return [passage], [" Based on the previous passage,", self.shortenable(question), [self.mask]]

        before, after = question.split('_', maxsplit=1)
        if self.pattern_id == 0:
            return [passage], [self.shortenable(before.rstrip()), [self.mask], self.shortenable(after)]

        before = before.rstrip()
        if before:
            before = self.lowercase_first(before)
        text_b = [" Based on the previous passage,", self.shortenable(before), [self.mask], self.shortenable(after)]
        return [passage], text_b

    def verbalize(self, label) -> List[str]:
        # Answers come from get_answers(); there is no per-label verbalizer.
        return []
-
-
class RtePVP(PVP):
    """PVP for RTE: does the premise (text_a) entail the hypothesis (text_b)?"""

    VERBALIZER = {
        "not_entailment": [" No"],
        "entailment": [" Yes"]
    }

    @staticmethod
    def available_patterns():
        return [0, 1, 2, 3, 4]

    @property
    def spell_length(self):
        return self.num_prompt_tokens + self.prefix_prompt

    def get_parts(self, example: InputExample) -> FilledPattern:
        """
        Build the pattern for the current pattern id.

        Patterns 0-3 put the hypothesis first and the premise after the mask. Pattern 4 asks a
        "question: ... True or False?" about the premise. ``None`` entries mark continuous prompt positions.

        :raises NotImplementedError: for pattern ids outside 0-4
        """
        # Hypothesis and premise are swapped relative to the example to get the right order.
        premise = example.text_a
        hypothesis = example.text_b.rstrip(string.punctuation)
        pid = self.pattern_id

        if pid == 0:
            first = [None, '"', self.shortenable(hypothesis), '" ?']
            second = [None, [self.mask], ',', None, ' "', self.shortenable(premise), '"']
        elif pid == 1:
            first = [None, self.shortenable(hypothesis), '?']
            second = [None, [self.mask], ',', None, self.shortenable(" " + premise)]
        elif pid == 2:
            first = [None, '"', self.shortenable(hypothesis), '" ?']
            second = [None, [self.mask], '. "', None, self.shortenable(premise), '"']
        elif pid == 3:
            first = [None, self.shortenable(hypothesis), '?']
            second = [None, [self.mask], '.', None, self.shortenable(" " + premise)]
        elif pid == 4:
            first = [None, self.shortenable(premise), None, ' question:', self.shortenable(" " + hypothesis),
                     ' True or False?', None, ' answer:', [self.mask]]
            second = []
        else:
            raise NotImplementedError(self.pattern_id)
        return self.replace_prompt_tokens(first, second)

    def verbalize(self, label) -> List[str]:
        """Map a label to its verbalization; pattern 4 uses ' true'/' false'."""
        if self.pattern_id != 4:
            return RtePVP.VERBALIZER[label]
        return [' true' if label == 'entailment' else ' false']
-
-
class CbPVP(RtePVP):
    """PVP for CommitmentBank (three-way NLI): reuses RtePVP's patterns 0-3 and defines its own pattern 4."""

    VERBALIZER = {
        "contradiction": [" No"],
        "entailment": [" Yes"],
        "neutral": [" Maybe"]
    }

    @staticmethod
    def available_patterns():
        return [0, 1, 2, 3, 4]

    def get_parts(self, example: InputExample) -> FilledPattern:
        """Pattern 4 asks "true, false or neither?"; all other patterns are inherited from RtePVP."""
        if self.pattern_id != 4:
            return super().get_parts(example)
        premise = self.shortenable(example.text_a)
        hypothesis = self.shortenable(" " + example.text_b)
        first = [None, premise, None, ' question:', hypothesis, ' true, false or neither?', None, ' answer:',
                 [self.mask]]
        return self.replace_prompt_tokens(first, [])

    def verbalize(self, label) -> List[str]:
        """Map a label to its verbalization; pattern 4 uses ' true' / ' false' / ' neither'."""
        if self.pattern_id != 4:
            return CbPVP.VERBALIZER[label]
        if label == 'entailment':
            return [' true']
        if label == 'contradiction':
            return [' false']
        return [' neither']
-
-
class BoolQPVP(PVP):
    """PVP for BoolQ: answer a yes/no question about a passage."""

    VERBALIZER_A = {
        "false": [" No"],
        "true": [" Yes"]
    }

    VERBALIZER_B = {
        "false": [" false"],
        "true": [" true"]
    }

    @staticmethod
    def available_patterns():
        return [0, 1, 2, 3, 4, 5]

    @property
    def spell_length(self):
        return self.num_prompt_tokens + self.prefix_prompt

    def get_parts(self, example: InputExample) -> FilledPattern:
        """
        Build the pattern: ids 0-1 "Question: ...? Answer:", ids 2-3 "Based on the previous passage, ...?",
        ids 4-5 "Based on the following passage ...?" with the passage last. ``None`` marks prompt positions.

        :raises NotImplementedError: for pattern ids of 6 or more
        """
        passage = example.text_a
        question = example.text_b
        pid = self.pattern_id

        if pid >= 6:
            raise NotImplementedError(self.pattern_id)
        asked = self.shortenable(" " + question)
        if pid >= 4:
            first = ['Based on the following passage', None, asked, '?', None,
                     [self.mask], '.', None, self.shortenable(" " + passage)]
        elif pid >= 2:
            first = [None, self.shortenable(passage), ' Based on the previous passage,', None,
                     asked, '?', None, [self.mask], '.']
        else:
            first = [None, self.shortenable(passage), None, ' Question:', asked,
                     '? Answer:', None, [self.mask], '.']
        return self.replace_prompt_tokens(first, [])

    def verbalize(self, label) -> List[str]:
        """Even pattern ids (0, 2, 4) use Yes/No; the others use true/false."""
        table = BoolQPVP.VERBALIZER_A if self.pattern_id in (0, 2, 4) else BoolQPVP.VERBALIZER_B
        return table[label]
-
-
class MultiRcPVP(PVP):
    """PVP for MultiRC: judge whether a candidate answer to a question about a passage is correct."""

    VERBALIZER = {
        0: [" No"],
        1: [" Yes"]
    }

    @staticmethod
    def available_patterns():
        return [0, 1, 2, 3, 4]

    @property
    def spell_length(self):
        return self.num_prompt_tokens + self.prefix_prompt

    def get_parts(self, example: InputExample) -> FilledPattern:
        """
        Build the (passage, question, candidate answer, mask) pattern for the current pattern id.

        ``None`` entries mark continuous prompt positions.

        :raises NotImplementedError: for pattern ids outside 0-4
        """
        passage = self.remove_final_punc(self.shortenable(example.text_a.rstrip()))
        question = self.remove_final_punc(example.text_b.rstrip())
        candidate = example.meta['answer']
        spaced_question = " " + question
        pid = self.pattern_id

        if pid == 0:
            first = [passage, '.', None, ' Question:', spaced_question + '?', None, ' Is it', " " + candidate,
                     '?', None, [self.mask], '.']
        elif pid == 1:
            first = [passage, '.', None, ' Question:', spaced_question, '?', None, ' Is the correct answer "',
                     candidate, '"?', None, [self.mask], '.']
        elif pid == 2:
            first = [passage, '. Based on the previous passage,', None, spaced_question, '?', None, ' Is "',
                     candidate, '" a correct answer?', None, [self.mask], '.']
        elif pid == 3:
            first = [None, passage, None, spaced_question, '- [', [self.mask], ']', None, candidate]
        elif pid == 4:
            first = [passage, '.', None, ' Question:', spaced_question, '?', None, " " + candidate, '?', None,
                     [self.mask], '.']
        else:
            raise NotImplementedError(self.pattern_id)
        return self.replace_prompt_tokens(first, [])

    def verbalize(self, label) -> List[str]:
        """Map a 0/1 label to its verbalization; pattern 3 uses ' False'/' True'."""
        if self.pattern_id != 3:
            return MultiRcPVP.VERBALIZER[label]
        return [' False' if label == 0 else ' True']
-
-
class WicPVP(PVP):
    """PVP for WiC: is a target word used with the same sense in two sentences?"""

    VERBALIZER_A = {
        "false": [" No"],
        "true": [" Yes"]
    }
    VERBALIZER_B = {
        "false": ["2"],
        "true": ["b"]
    }

    @staticmethod
    def available_patterns():
        return [0, 1, 2]

    @property
    def spell_length(self):
        return self.num_prompt_tokens + self.prefix_prompt

    def get_parts(self, example: InputExample) -> FilledPattern:
        """
        Build the sentence-pair pattern for the current pattern id.

        ``None`` entries mark continuous prompt positions.

        :raises NotImplementedError: for pattern ids outside 0-2
        """
        first_sentence = example.text_a
        second_sentence = example.text_b
        word = example.meta['word']
        pid = self.pattern_id

        if pid == 0:
            pair = '"' + first_sentence + '" / "' + second_sentence + '"'
            prompt = ' Similar sense of "' + word + '"?'
            first = [None, self.shortenable(pair), None, prompt, None, [self.mask], '.']
        elif pid == 1:
            prompt = ' Does ' + word + ' have the same meaning in both sentences?'
            first = [self.shortenable(first_sentence), None, self.shortenable(" " + second_sentence), None,
                     prompt, None, [self.mask]]
        elif pid == 2:
            first = [None, word, ' .', None, ' Sense (1) (a) "', self.shortenable(first_sentence), '"', None, ' (',
                     [self.mask], ') "', second_sentence, '"']
        else:
            raise NotImplementedError(self.pattern_id)
        return self.replace_prompt_tokens(first, [])

    def verbalize(self, label) -> List[str]:
        """Pattern 2 uses the '2'/'b' option markers; the other patterns use Yes/No."""
        table = WicPVP.VERBALIZER_B if self.pattern_id == 2 else WicPVP.VERBALIZER_A
        return table[label]
-
-
class AgnewsPVP(PVP):
    """PVP for AG News topic classification; labels "1".."4" map to World/Sports/Business/Tech."""

    VERBALIZER = {
        "1": [" World"],
        "2": [" Sports"],
        "3": [" Business"],
        "4": [" Tech"],
    }

    @staticmethod
    def available_patterns():
        """Return the pattern ids handled by ``get_parts`` (0 through 5)."""
        return list(range(6))

    def get_parts(self, example: InputExample) -> FilledPattern:
        """Place the mask around ``text_a``/``text_b`` (both shortenable) according to ``self.pattern_id``.

        :raises ValueError: if ``self.pattern_id`` is not in 0..5
        """
        first = self.shortenable(example.text_a)
        second = self.shortenable(example.text_b)
        mask = [self.mask]

        layouts = {
            0: [mask, ':', first, second],
            1: [mask, ' News:', first, second],
            2: [first, '(', mask, ')', second],
            3: [first, second, '(', mask, ')'],
            4: ['[ Category:', mask, ']', first, second],
            5: [mask, '-', first, second],
        }
        if self.pattern_id not in layouts:
            raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
        return layouts[self.pattern_id], []

    def verbalize(self, label) -> List[str]:
        """Return the verbalizer tokens for ``label``."""
        words = AgnewsPVP.VERBALIZER
        return words[label]
-
-
class YahooPVP(PVP):
    """PVP for Yahoo Answers topic classification; labels "1".."10" map to one-word topic names."""

    VERBALIZER = {
        "1": [" Society"],
        "2": [" Science"],
        "3": [" Health"],
        "4": [" Education"],
        "5": [" Computer"],
        "6": [" Sports"],
        "7": [" Business"],
        "8": [" Entertainment"],
        "9": [" Relationship"],
        "10": [" Politics"],
    }

    @staticmethod
    def available_patterns():
        """Return the pattern ids handled by ``get_parts`` (0 through 5)."""
        return list(range(6))

    def get_parts(self, example: InputExample) -> FilledPattern:
        """Place the mask around ``text_a``/``text_b`` (both shortenable) according to ``self.pattern_id``.

        :raises ValueError: if ``self.pattern_id`` is not in 0..5
        """
        first = self.shortenable(example.text_a)
        second = self.shortenable(example.text_b)
        pid = self.pattern_id

        if pid == 0:
            segment = [[self.mask], ':', first, second]
        elif pid == 1:
            segment = [[self.mask], ' Question:', first, second]
        elif pid == 2:
            segment = [first, '(', [self.mask], ')', second]
        elif pid == 3:
            segment = [first, second, '(', [self.mask], ')']
        elif pid == 4:
            segment = ['[ Category:', [self.mask], ']', first, second]
        elif pid == 5:
            segment = [[self.mask], '-', first, second]
        else:
            raise ValueError("No pattern implemented for id {}".format(pid))
        return segment, []

    def verbalize(self, label) -> List[str]:
        """Return the verbalizer tokens for ``label``."""
        words = YahooPVP.VERBALIZER
        return words[label]
-
-
class MnliPVP(PVP):
    """PVP for MNLI sentence pairs.

    Patterns 0 and 2 wrap both sentences in quotes; patterns 1 and 3 use them bare.
    Patterns 0 and 1 verbalize with Wrong/Right/Maybe, patterns 2 and 3 with No/Yes/Maybe.
    """

    VERBALIZER_A = {
        "contradiction": [" Wrong"],
        "entailment": [" Right"],
        "neutral": [" Maybe"]
    }
    VERBALIZER_B = {
        "contradiction": [" No"],
        "entailment": [" Yes"],
        "neutral": [" Maybe"]
    }

    @staticmethod
    def available_patterns():
        return [0, 1, 2, 3]

    def get_parts(self, example: InputExample) -> FilledPattern:
        """Build the MNLI pattern: ``text_a ?`` in the first part, ``[MASK], text_b`` in the second.

        :param example: the input example (``text_a`` has its final punctuation removed)
        :return: the two part lists
        :raises ValueError: if ``self.pattern_id`` is not in 0..3. Previously this fell through and
            returned ``None``, which only failed later with an unrelated error.
        """
        text_a = self.shortenable(self.remove_final_punc(example.text_a))
        text_b = self.shortenable(example.text_b)

        if self.pattern_id in (0, 2):
            return ['"', text_a, '" ?'], [[self.mask], ', "', text_b, '"']
        elif self.pattern_id in (1, 3):
            return [text_a, '?'], [[self.mask], ',', text_b]
        # Match the other PVPs instead of silently returning None.
        raise ValueError("No pattern implemented for id {}".format(self.pattern_id))

    def verbalize(self, label) -> List[str]:
        """Return the verbalizer tokens for ``label``; patterns 0/1 use ``VERBALIZER_A``, others ``VERBALIZER_B``."""
        if self.pattern_id == 0 or self.pattern_id == 1:
            return MnliPVP.VERBALIZER_A[label]
        return MnliPVP.VERBALIZER_B[label]
-
-
class YelpPolarityPVP(PVP):
    """PVP for binary Yelp sentiment; label "1" -> " bad", label "2" -> " good"."""

    VERBALIZER = {
        "1": [" bad"],
        "2": [" good"],
    }

    @staticmethod
    def available_patterns():
        """Return the pattern ids handled by ``get_parts`` (0 through 3)."""
        return list(range(4))

    def get_parts(self, example: InputExample) -> FilledPattern:
        """Combine the shortenable ``text_a`` with a short sentiment prompt chosen by ``self.pattern_id``.

        Patterns 2 and 3 put the text and the prompt in different parts.

        :raises ValueError: if ``self.pattern_id`` is not in 0..3
        """
        review = self.shortenable(example.text_a)
        mask = [self.mask]

        templates = {
            0: (['It was', mask, '.', review], []),
            1: ([review, '. All in all, it was', mask, '.'], []),
            2: (['Just', mask, "!"], [review]),
            3: ([review], [' In summary, the restaurant is', mask, '.']),
        }
        if self.pattern_id not in templates:
            raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
        return templates[self.pattern_id]

    def verbalize(self, label) -> List[str]:
        """Return the verbalizer tokens for ``label``."""
        words = YelpPolarityPVP.VERBALIZER
        return words[label]
-
-
class YelpFullPVP(YelpPolarityPVP):
    """Five-class Yelp PVP: reuses the YelpPolarityPVP patterns with a five-word verbalizer."""

    # Labels "1".."5" map to increasingly positive words.
    VERBALIZER = {
        "1": [" terrible"],
        "2": [" bad"],
        "3": [" okay"],
        "4": [" good"],
        "5": [" great"],
    }

    def verbalize(self, label) -> List[str]:
        """Return the verbalizer tokens for ``label``."""
        words = YelpFullPVP.VERBALIZER
        return words[label]
-
-
class XStancePVP(PVP):
    """PVP for X-Stance with per-language Yes/No verbalizers.

    Even pattern ids quote both texts and odd ids use them bare. The verbalizer language is
    German for patterns 0-1, English for 2-3 and French for 4-5.
    """

    VERBALIZERS = {
        'en': {"FAVOR": ["Yes"], "AGAINST": ["No"]},
        'de': {"FAVOR": ["Ja"], "AGAINST": ["Nein"]},
        'fr': {"FAVOR": ["Oui"], "AGAINST": ["Non"]}
    }

    @staticmethod
    def available_patterns():
        return [0, 1, 2, 3, 4, 5]

    def get_parts(self, example: InputExample) -> FilledPattern:
        """Build the pattern: ``text_a`` in the first part, ``[MASK]. text_b`` in the second.

        :param example: the input example
        :return: the two part lists
        :raises ValueError: if ``self.pattern_id`` is not in 0..5. Previously this fell through and
            returned ``None``, which only failed later with an unrelated error.
        """
        text_a = self.shortenable(example.text_a)
        text_b = self.shortenable(example.text_b)

        if self.pattern_id in (0, 2, 4):
            return ['"', text_a, '"'], [[self.mask], '. "', text_b, '"']
        elif self.pattern_id in (1, 3, 5):
            return [text_a], [[self.mask], '.', text_b]
        # Match the other PVPs instead of silently returning None.
        raise ValueError("No pattern implemented for id {}".format(self.pattern_id))

    def verbalize(self, label) -> List[str]:
        """Return the verbalizer tokens for ``label`` in the language selected by ``self.pattern_id``."""
        lang = 'de' if self.pattern_id < 2 else 'en' if self.pattern_id < 4 else 'fr'
        return XStancePVP.VERBALIZERS[lang][label]
-
-
class Sst2PVP(PVP):
    """PVP for SST-2; both patterns share one prompt and differ only in verbalizer."""

    # Used by pattern 0.
    VERBALIZER_A = {
        "0": [" terrible"],
        "1": [" great"],
    }

    # Used by every other pattern.
    VERBALIZER_B = {
        "0": [" bad"],
        "1": [" good"],
    }

    @staticmethod
    def available_patterns():
        """Return the pattern ids handled by ``get_parts`` (0 and 1)."""
        return list(range(2))

    def get_parts(self, example: InputExample) -> FilledPattern:
        """Return ``text_a It was [MASK].`` with ``text_a`` shortenable.

        :raises ValueError: if ``self.pattern_id`` is not 0 or 1
        """
        sentence = self.shortenable(example.text_a)
        if self.pattern_id not in (0, 1):
            raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
        return [sentence, ' It was', [self.mask], '.'], []

    def verbalize(self, label) -> List[str]:
        """Return the verbalizer tokens for ``label``."""
        table = Sst2PVP.VERBALIZER_A if self.pattern_id == 0 else Sst2PVP.VERBALIZER_B
        return table[label]
-
-
class ColaPVP(PVP):
    """PVP for CoLA; label "0" -> " incorrect", label "1" -> " correct"."""

    VERBALIZER = {
        "0": [" incorrect"],
        "1": [" correct"],
    }

    def get_parts(self, example: InputExample) -> FilledPattern:
        """Return ``"text_a" This is [MASK].`` with ``text_a`` shortenable.

        :raises ValueError: if ``self.pattern_id`` is not 0
        """
        sentence = self.shortenable(example.text_a)
        if self.pattern_id != 0:
            raise ValueError("No pattern implemented for id {}".format(self.pattern_id))
        return ['"', sentence, '"', " This is", [self.mask], '.'], []

    def verbalize(self, label) -> List[str]:
        """Return the verbalizer tokens for ``label``."""
        words = ColaPVP.VERBALIZER
        return words[label]
-
-
class MrpcPVP(PVP):
    """PVP for MRPC sentence pairs with a Yes/No verbalizer."""

    # Label "0" -> " No", label "1" -> " Yes".
    VERBALIZER = {
        "0": [" No"],
        "1": [" Yes"]
    }

    @staticmethod
    def available_patterns():
        return [0, 1]

    def get_parts(self, example: InputExample) -> FilledPattern:
        """Build the MRPC pattern for ``example``.

        Pattern 0: ``text_a`` || ``[MASK], text_b``.
        Pattern 1: ``text_a`` || `` Does it mean that text_b? [MASK]``.
        In both, ``text_b`` has its first character lowercased; pattern 1 also drops its final punctuation.

        :param example: the input example
        :return: the two part lists
        :raises ValueError: if ``self.pattern_id`` is not 0 or 1
        """
        text_a = self.shortenable(example.text_a)
        if self.pattern_id == 0:
            text_b = self.shortenable(self.lowercase_first(example.text_b))
            return [text_a], [[self.mask], ', ', text_b]
        elif self.pattern_id == 1:
            # " Does it mean that" has no trailing space, so text_b must carry the separating space
            # itself, as " " + question does in the other QA-style patterns. Without it the segments
            # read "...that<text_b>", assuming parts are concatenated with no implicit separator
            # (TODO confirm for the tokenizer in use).
            text_b = self.shortenable(" " + self.remove_final_punc(self.lowercase_first(example.text_b)))
            return [text_a], [" Does it mean that", text_b, "?", [self.mask]]
        else:
            raise ValueError("No pattern implemented for id {}".format(self.pattern_id))

    def verbalize(self, label) -> List[str]:
        """Return the verbalizer tokens for ``label``."""
        return MrpcPVP.VERBALIZER[label]
-
-
class QqpPVP(PVP):
    """PVP for QQP question pairs; label "0" -> " No", label "1" -> " Yes"."""

    VERBALIZER = {
        "0": [" No"],
        "1": [" Yes"],
    }

    @staticmethod
    def available_patterns():
        """Return the pattern ids handled by ``get_parts`` (0 and 1)."""
        return list(range(2))

    def get_parts(self, example: InputExample) -> FilledPattern:
        """Put ``text_a`` in the first part and a masked prompt around ``text_b`` in the second.

        ``text_b`` has its first character lowercased; both texts are shortenable.

        :raises ValueError: if ``self.pattern_id`` is not 0 or 1
        """
        left = self.shortenable(example.text_a)
        right = self.shortenable(self.lowercase_first(example.text_b))
        pid = self.pattern_id
        if pid not in (0, 1):
            raise ValueError("No pattern implemented for id {}".format(pid))
        if pid == 0:
            second_segment = [" Do you mean ", right, [self.mask], "."]
        else:
            second_segment = [[self.mask], ', ', right]
        return [left], second_segment

    def verbalize(self, label) -> List[str]:
        """Return the verbalizer tokens for ``label``."""
        words = QqpPVP.VERBALIZER
        return words[label]
-
-
class QnliPVP(PVP):
    """PVP for QNLI: ``text_a`` is the question and ``text_b`` the passage."""

    VERBALIZER = {
        "not_entailment": [" No"],
        "entailment": [" Yes"],
    }

    @staticmethod
    def available_patterns():
        """Return the pattern ids handled by ``get_parts`` (0, 1 and 2)."""
        return list(range(3))

    def get_parts(self, example: InputExample) -> FilledPattern:
        """Ask whether the passage answers the question (question loses its final punctuation).

        :raises ValueError: if ``self.pattern_id`` is not 0, 1 or 2
        """
        query = self.remove_final_punc(example.text_a)
        context = example.text_b
        pid = self.pattern_id

        if pid not in (0, 1, 2):
            raise ValueError("No pattern implemented for id {}".format(pid))
        if pid == 0:
            segment = [self.shortenable(context), ' Question:', self.shortenable(" " + query),
                       '? Do you know the answer?', [self.mask], '.']
        elif pid == 1:
            segment = [self.shortenable(context), ' Based on the previous passage, do you know the answer',
                       self.shortenable(" " + query),
                       '?', [self.mask], '.']
        else:
            segment = ['Based on the following passage, do you know the answer', self.shortenable(" " + query), '?',
                       [self.mask], '.', self.shortenable(" " + context)]
        return segment, []

    def verbalize(self, label) -> List[str]:
        """Return the verbalizer tokens for ``label``."""
        words = QnliPVP.VERBALIZER
        return words[label]
-
-
class SquadPVP(PVP):
    """PVP for SQuAD-style extraction: the answer is generated in place of the mask (multi-token)."""

    @property
    def is_multi_token(self):
        """Always ``True``: answers may span several tokens."""
        return True

    def get_answers(self, example: InputExample):
        """Return a single-element list holding the gold answer text with a leading space."""
        gold = example.meta['answer']['text']
        return [" " + gold]

    def get_parts(self, example: InputExample) -> FilledPattern:
        """Return ``context question [MASK].`` with ``text_a`` (context) shortenable."""
        passage = self.shortenable(example.text_a)
        query = example.text_b
        return [passage, " " + query, [self.mask], "."], []

    def verbalize(self, label) -> List[str]:
        """No fixed verbalizer; answers come from ``get_answers`` instead."""
        return []
-
-
def get_verbalization_ids(word: str, tokenizer, force_single_token: bool) -> Union[int, List[int]]:
    """
    Tokenize a verbalization and return its token id(s).

    :param word: the verbalization text
    :param tokenizer: tokenizer exposing ``EncodeAsIds(...).tokenization``, ``DecodeIds``, ``IdToToken``
        and ``command_id_map``
    :param force_single_token: if true, require exactly one token and return it as a plain int;
        otherwise return the full list of ids
    :return: the list of token ids, or the single id when ``force_single_token`` is set
    :raises AssertionError: if a forced single-token word yields more or fewer than one token, or its
        token id is a key of ``tokenizer.command_id_map``
    """
    token_ids = tokenizer.EncodeAsIds(word).tokenization
    if force_single_token:
        assert len(token_ids) == 1, \
            f'Verbalization "{word}" does not correspond to a single token, got {tokenizer.DecodeIds(token_ids)}'
        single_id = token_ids[0]
        # Special/command tokens cannot serve as label words.
        assert single_id not in tokenizer.command_id_map, \
            f'Verbalization {word} is mapped to a special token {tokenizer.IdToToken(single_id)}'
        return single_id
    return token_ids
-
-
# Registry: task name -> PVP class providing that task's patterns and verbalizers.
# Some task names share a class (the xstance variants; rte, ax-b and ax-g).
# Entries are listed in the original insertion order.
PVPS = dict([
    ('agnews', AgnewsPVP),
    ('mnli', MnliPVP),
    ('yelp-polarity', YelpPolarityPVP),
    ('yelp-full', YelpFullPVP),
    ('yahoo', YahooPVP),
    ('xstance', XStancePVP),
    ('xstance-de', XStancePVP),
    ('xstance-fr', XStancePVP),
    ('rte', RtePVP),
    ('wic', WicPVP),
    ('cb', CbPVP),
    ('wsc', WscPVP),
    ('boolq', BoolQPVP),
    ('copa', CopaPVP),
    ('multirc', MultiRcPVP),
    ('record', RecordPVP),
    ('ax-b', RtePVP),
    ('ax-g', RtePVP),
    ('sst2', Sst2PVP),
    ('cola', ColaPVP),
    ('mrpc', MrpcPVP),
    ('qqp', QqpPVP),
    ('qnli', QnliPVP),
    ('squad', SquadPVP),
    ('race', RacePVP),
])
|