Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 42 additions & 48 deletions model2vec/distill/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np
from huggingface_hub.hf_api import model_info
from skeletoken import TokenizerModel
from transformers import AutoModel, AutoTokenizer
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
Expand All @@ -15,7 +16,7 @@
from model2vec.distill.utils import select_optimal_device
from model2vec.model import StaticModel
from model2vec.quantization import DType, quantize_embeddings
from model2vec.tokenizer import clean_and_create_vocabulary, replace_vocabulary, turn_tokens_into_ids
from model2vec.tokenizer import clean_and_create_vocabulary, turn_tokens_into_ids
from model2vec.vocabulary_quantization import quantize_vocabulary

logger = logging.getLogger(__name__)
Expand All @@ -37,7 +38,8 @@ def distill_from_model(
Distill a staticmodel from a sentence transformer.

This function creates a set of embeddings from a sentence transformer. It does this by doing either
a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed vocabulary.
a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed
vocabulary.

If you pass through a vocabulary, we create a custom word tokenizer for that vocabulary.
If you don't pass a vocabulary, we use the model's tokenizer directly.
Expand All @@ -51,10 +53,13 @@ def distill_from_model(
If this is 'auto', we don't reduce dimensionality, but still apply PCA.
:param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to
this regex pattern will be removed from the vocabulary.
If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled
into a valid regex, we also throw an error.
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no
quantization is performed.
:param pooling: The pooling mode to use for creating embeddings. Can be one of:
'mean' (default): mean over all tokens. Robust and works well in most cases.
'last': use the last token's hidden state (often the [EOS] token). Common for decoder-style models.
Expand All @@ -65,59 +70,43 @@ def distill_from_model(

"""
quantize_to = DType(quantize_to)
backend_tokenizer = tokenizer.backend_tokenizer
sif_coefficient, token_remove_regex = _validate_parameters(sif_coefficient, token_remove_pattern)

if vocabulary is None:
vocabulary = []

device = select_optimal_device(device)
original_tokenizer_model = TokenizerModel.from_transformers_tokenizer(tokenizer)

n_tokens_before = len(vocabulary)
# Clean the vocabulary by removing duplicate tokens and tokens that are in the internal vocabulary.
all_tokens, backend_tokenizer = clean_and_create_vocabulary(
tokenizer, vocabulary, token_remove_regex=token_remove_regex
)
n_tokens_after = len([token for token in all_tokens if not token.is_internal])
if n_tokens_before:
logger.info(
f"Adding {n_tokens_after} tokens to the vocabulary. Removed {n_tokens_before - n_tokens_after} tokens during preprocessing."
)

# Copy the original tokenizer model.
tokenizer_model = original_tokenizer_model._deep_copy()
if tokenizer_model.adds_prefix_space is not None:
tokenizer_model.adds_prefix_space = True

# Create the vocabulary in the new tokenizer.
tokenizer_model = clean_and_create_vocabulary(tokenizer_model, vocabulary, token_remove_regex=token_remove_regex)
# Remove the post processor, this is not necessary.
tokenizer_model.post_processor = None

# All tokens in a single list.
all_tokens = tokenizer_model.sorted_vocabulary
if not all_tokens:
raise ValueError("The vocabulary is empty after preprocessing. Please check your token_remove_pattern.")

unk_token = cast(str | None, tokenizer.special_tokens_map.get("unk_token"))
pad_token = cast(str | None, tokenizer.special_tokens_map.get("pad_token"))

# Weird if to satsify mypy
if pad_token is None:
if unk_token is not None:
pad_token = unk_token
logger.warning(
"The pad token is not set. Setting it to the unk token. This is a workaround for models that don't have a pad token."
)
else:
pad_token = unk_token or all_tokens[0].form
logger.warning(
"The pad token is not set. Setting it to the first token in the vocabulary. This is a workaround for models that don't have a pad token."
)

# Replace the vocabulary in the tokenizer with the new vocabulary.
backend_tokenizer = replace_vocabulary(backend_tokenizer, all_tokens, unk_token=unk_token, pad_token=pad_token)
logger.info(f"Creating embeddings for {len(all_tokens)} tokens")
# Convert tokens to IDs
token_ids = turn_tokens_into_ids(all_tokens, tokenizer, unk_token)

# Create the embeddings
# Turn all _new_ tokens into ids using the original tokenizer
token_ids = turn_tokens_into_ids(all_tokens, original_tokenizer_model)

# Create the embeddings using the ids from the original tokenizer.
embeddings = create_embeddings(
tokenized=token_ids,
model=model,
device=device,
pad_token_id=tokenizer.get_vocab()[pad_token],
pad_token_id=tokenizer_model.pad_token_id or 0,
pooling=pooling,
)

# Maybe apply quantization
if vocabulary_quantization is not None:
_, weights = post_process_embeddings(np.asarray(embeddings), None, sif_coefficient=sif_coefficient)
embeddings, token_mapping, weights = quantize_vocabulary(
Expand Down Expand Up @@ -163,7 +152,7 @@ def distill_from_model(
vectors=embeddings,
weights=weights,
token_mapping=token_mapping,
tokenizer=backend_tokenizer,
tokenizer=tokenizer_model.to_tokenizer(),
config=config,
base_model_name=model_name,
language=language,
Expand All @@ -174,13 +163,14 @@ def distill_from_model(
def _validate_parameters(
sif_coefficient: float | None,
token_remove_pattern: str | None,
) -> tuple[float | None, re.Pattern | None]:
) -> tuple[float | None, re.Pattern[str] | None]:
"""
Validate the parameters passed to the distillation function.

:param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
Should be a value >= 0 and < 1.0. A value of 1e-4 is a good default.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to
this regex pattern will be removed from the vocabulary.
:return: The SIF coefficient to use.
:raises: ValueError if the regex can't be compiled.

Expand All @@ -189,7 +179,7 @@ def _validate_parameters(
if not 0 < sif_coefficient < 1.0:
raise ValueError("SIF coefficient must be a value > 0 and < 1.0.")

token_remove_regex: re.Pattern | None = None
token_remove_regex: re.Pattern[str] | None = None
if token_remove_pattern is not None:
try:
token_remove_regex = re.compile(token_remove_pattern)
Expand All @@ -215,7 +205,8 @@ def distill(
Distill a staticmodel from a sentence transformer.

This function creates a set of embeddings from a sentence transformer. It does this by doing either
a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed vocabulary.
a forward pass for all subword tokens in the tokenizer, or by doing a forward pass for all tokens in a passed
vocabulary.

If you pass through a vocabulary, we create a custom word tokenizer for that vocabulary.
If you don't pass a vocabulary, we use the model's tokenizer directly.
Expand All @@ -228,10 +219,13 @@ def distill(
If this is 'auto', we don't reduce dimenionality, but still apply PCA.
:param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
Should be a value >= 0 and < 1.0. A value of 1e-4 is a good default.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
:param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to
this regex pattern will be removed from the vocabulary.
:param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming
from `transformers`. If this is True, we will load all components.
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no quantization is performed.
:param vocabulary_quantization: The number of clusters to use for vocabulary quantization. If this is None, no
quantization is performed.
:param pooling: The pooling mode to use for creating embeddings. Can be one of:
'mean' (default): mean over all tokens. Robust and works well in most cases.
'last': use the last token's hidden state (often the [EOS] token). Common for decoder-style models.
Expand Down
4 changes: 1 addition & 3 deletions model2vec/tokenizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

from model2vec.tokenizer.tokenizer import (
clean_and_create_vocabulary,
create_tokenizer,
replace_vocabulary,
turn_tokens_into_ids,
)

__all__ = ["clean_and_create_vocabulary", "create_tokenizer", "turn_tokens_into_ids", "replace_vocabulary"]
__all__ = ["clean_and_create_vocabulary", "turn_tokens_into_ids"]
14 changes: 0 additions & 14 deletions model2vec/tokenizer/datamodels.py

This file was deleted.

43 changes: 0 additions & 43 deletions model2vec/tokenizer/model.py

This file was deleted.

42 changes: 0 additions & 42 deletions model2vec/tokenizer/normalizer.py

This file was deleted.

57 changes: 0 additions & 57 deletions model2vec/tokenizer/pretokenizer.py

This file was deleted.

Loading