Paper: Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks (arXiv: 1908.10084)
This is a Cross Encoder model finetuned from distilbert/distilroberta-base on the stsb dataset using the sentence-transformers library. It computes scores for pairs of texts, which can be used for semantic textual similarity, semantic search, paraphrase mining, text classification, clustering, and more.
First install the Sentence Transformers library:
```bash
pip install -U sentence-transformers
```
Then you can load this model and run inference.
```python
from sentence_transformers import CrossEncoder

# Download from the 🤗 Hub
model = CrossEncoder("tomaarsen/reranker-distilroberta-base-stsb")

# Get scores for pairs...
pairs = [
    ['A man with a hard hat is dancing.', 'A man wearing a hard hat is dancing.'],
    ['A young child is riding a horse.', 'A child is riding a horse.'],
    ['A man is feeding a mouse to a snake.', 'The man is feeding a mouse to the snake.'],
    ['A woman is playing the guitar.', 'A man is playing guitar.'],
    ['A woman is playing the flute.', 'A man is playing a flute.'],
]
scores = model.predict(pairs)
print(scores.shape)
# (5,)

# ... or rank different texts based on similarity to a single text
ranks = model.rank(
    'A man with a hard hat is dancing.',
    [
        'A man wearing a hard hat is dancing.',
        'A child is riding a horse.',
        'The man is feeding a mouse to the snake.',
        'A man is playing guitar.',
        'A man is playing a flute.',
    ]
)
# [{'corpus_id': ..., 'score': ...}, {'corpus_id': ..., 'score': ...}, ...]
```
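Conceptually, `model.rank` scores every (query, document) pair and then orders the documents by score. The sketch below reproduces only that ordering step in plain Python so it runs without downloading the model; the scores are made up and stand in for real `model.predict` output, and `rank_by_score` is a hypothetical helper, not the library's implementation.

```python
from typing import Dict, List, Union


def rank_by_score(scores: List[float]) -> List[Dict[str, Union[int, float]]]:
    """Return documents ordered from highest to lowest score.

    Each entry mirrors the shape shown in the usage example:
    {'corpus_id': <index into the document list>, 'score': <score>}.
    """
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return [{"corpus_id": i, "score": scores[i]} for i in order]


documents = [
    "A man wearing a hard hat is dancing.",
    "A child is riding a horse.",
    "The man is feeding a mouse to the snake.",
    "A man is playing guitar.",
    "A man is playing a flute.",
]
# Hypothetical scores for (query, document) pairs; real ones would come from
# model.predict([(query, doc) for doc in documents]).
scores = [0.98, 0.02, 0.01, 0.05, 0.04]

for entry in rank_by_score(scores):
    print(f"{entry['score']:.2f}\t{documents[entry['corpus_id']]}")
```

Because `corpus_id` is an index into the original list, the ranked output can always be mapped back to the input texts, as the print loop does.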
Evaluated on `stsb-validation` and `stsb-test` with `CECorrelationEvaluator`:

| Metric | stsb-validation | stsb-test |
|---|---|---|
| pearson | 0.8773 | 0.8503 |
| spearman | 0.8754 | 0.8389 |
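`CECorrelationEvaluator` reports how closely the predicted scores track the gold `score` labels, as Pearson (linear) and Spearman (rank) correlation. Below is a dependency-free sketch of both metrics; the gold and predicted values are made up for illustration, and ties get averaged ranks, the usual convention (for example in `scipy.stats.spearmanr`). It is not the evaluator's own code.

```python
import math
from typing import List


def pearson(x: List[float], y: List[float]) -> float:
    """Pearson correlation of two equal-length sequences (undefined if either is constant)."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)


def average_ranks(values: List[float]) -> List[float]:
    """1-based ranks; tied values share the mean of the ranks they span."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with the one at position i.
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        mean_rank = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = mean_rank
        i = j + 1
    return ranks


def spearman(x: List[float], y: List[float]) -> float:
    """Spearman correlation: Pearson correlation of the tie-averaged ranks."""
    return pearson(average_ranks(x), average_ranks(y))


# Hypothetical gold labels (stsb-style scores in [0, 1]) and model predictions.
gold = [1.0, 0.95, 1.0, 0.5, 0.2]
pred = [0.97, 0.90, 0.93, 0.60, 0.10]
print(f"pearson:  {pearson(gold, pred):.4f}")
print(f"spearman: {spearman(gold, pred):.4f}")
```

Spearman only cares about ordering, so a model whose scores are monotonically related to the labels but not linearly calibrated can reach a Spearman of 1.0 while its Pearson stays below 1.0.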
Training dataset: stsb

Columns: `sentence1`, `sentence2`, and `score`.

| | sentence1 | sentence2 | score |
|---|---|---|---|
| type | string | string | float |

Samples:

| sentence1 | sentence2 | score |
|---|---|---|
| A plane is taking off. | An air plane is taking off. | 1.0 |
| A man is playing a large flute. | A man is playing a flute. | 0.76 |
| A man is spreading shreded cheese on a pizza. | A man is spreading shredded cheese on an uncooked pizza. | 0.76 |
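The `score` column holds floats such as 1.0 and 0.76 rather than hard 0/1 labels. Assuming `BinaryCrossEntropyLoss` applies a sigmoid to the model's single output logit and uses these floats directly as targets (binary cross-entropy with soft labels), the loss for a pair cannot drop below the entropy of its target unless the target is exactly 0 or 1. The sketch below checks this numerically for one pair; `bce_with_logits` and `binary_entropy` are hypothetical helpers, not the library's code. If the assumption holds, it may be part of why the training loss in the log levels off around 0.5 rather than approaching 0.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy between sigmoid(logit) and a target in [0, 1]."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def binary_entropy(t: float) -> float:
    """Entropy of a Bernoulli(t) variable: the smallest BCE reachable for target t."""
    if t in (0.0, 1.0):
        return 0.0
    return -(t * math.log(t) + (1.0 - t) * math.log(1.0 - t))


target = 0.76  # a soft label from the samples table
best_logit = math.log(target / (1.0 - target))  # sigmoid(best_logit) == target
for logit in (-2.0, 0.0, best_logit, 3.0):
    print(f"logit={logit:+.3f}  loss={bce_with_logits(logit, target):.4f}")
print(f"floor for target {target}: {binary_entropy(target):.4f}")
```

The minimum sits exactly where the predicted probability equals the label, so a well-trained model is pushed toward calibrated similarity scores rather than toward saturated 0/1 outputs.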
Loss: `BinaryCrossEntropyLoss`

Evaluation dataset: stsb

Columns: `sentence1`, `sentence2`, and `score`.

| | sentence1 | sentence2 | score |
|---|---|---|---|
| type | string | string | float |

Samples:

| sentence1 | sentence2 | score |
|---|---|---|
| A man with a hard hat is dancing. | A man wearing a hard hat is dancing. | 1.0 |
| A young child is riding a horse. | A child is riding a horse. | 0.95 |
| A man is feeding a mouse to a snake. | The man is feeding a mouse to the snake. | 1.0 |
Loss: `BinaryCrossEntropyLoss`

Non-default hyperparameters:

- `eval_strategy`: steps
- `per_device_train_batch_size`: 64
- `per_device_eval_batch_size`: 64
- `num_train_epochs`: 4
- `warmup_ratio`: 0.1
- `bf16`: True

All hyperparameters:

- `overwrite_output_dir`: False
- `do_predict`: False
- `eval_strategy`: steps
- `prediction_loss_only`: True
- `per_device_train_batch_size`: 64
- `per_device_eval_batch_size`: 64
- `per_gpu_train_batch_size`: None
- `per_gpu_eval_batch_size`: None
- `gradient_accumulation_steps`: 1
- `eval_accumulation_steps`: None
- `torch_empty_cache_steps`: None
- `learning_rate`: 5e-05
- `weight_decay`: 0.0
- `adam_beta1`: 0.9
- `adam_beta2`: 0.999
- `adam_epsilon`: 1e-08
- `max_grad_norm`: 1.0
- `num_train_epochs`: 4
- `max_steps`: -1
- `lr_scheduler_type`: linear
- `lr_scheduler_kwargs`: {}
- `warmup_ratio`: 0.1
- `warmup_steps`: 0
- `log_level`: passive
- `log_level_replica`: warning
- `log_on_each_node`: True
- `logging_nan_inf_filter`: True
- `save_safetensors`: True
- `save_on_each_node`: False
- `save_only_model`: False
- `restore_callback_states_from_checkpoint`: False
- `no_cuda`: False
- `use_cpu`: False
- `use_mps_device`: False
- `seed`: 42
- `data_seed`: None
- `jit_mode_eval`: False
- `use_ipex`: False
- `bf16`: True
- `fp16`: False
- `fp16_opt_level`: O1
- `half_precision_backend`: auto
- `bf16_full_eval`: False
- `fp16_full_eval`: False
- `tf32`: None
- `local_rank`: 0
- `ddp_backend`: None
- `tpu_num_cores`: None
- `tpu_metrics_debug`: False
- `debug`: []
- `dataloader_drop_last`: False
- `dataloader_num_workers`: 0
- `dataloader_prefetch_factor`: None
- `past_index`: -1
- `disable_tqdm`: False
- `remove_unused_columns`: True
- `label_names`: None
- `load_best_model_at_end`: False
- `ignore_data_skip`: False
- `fsdp`: []
- `fsdp_min_num_params`: 0
- `fsdp_config`: {'min_num_params': 0, 'xla': False, 'xla_fsdp_v2': False, 'xla_fsdp_grad_ckpt': False}
- `fsdp_transformer_layer_cls_to_wrap`: None
- `accelerator_config`: {'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None}
- `deepspeed`: None
- `label_smoothing_factor`: 0.0
- `optim`: adamw_torch
- `optim_args`: None
- `adafactor`: False
- `group_by_length`: False
- `length_column_name`: length
- `ddp_find_unused_parameters`: None
- `ddp_bucket_cap_mb`: None
- `ddp_broadcast_buffers`: False
- `dataloader_pin_memory`: True
- `dataloader_persistent_workers`: False
- `skip_memory_metrics`: True
- `use_legacy_prediction_loop`: False
- `push_to_hub`: False
- `resume_from_checkpoint`: None
- `hub_model_id`: None
- `hub_strategy`: every_save
- `hub_private_repo`: None
- `hub_always_push`: False
- `gradient_checkpointing`: False
- `gradient_checkpointing_kwargs`: None
- `include_inputs_for_metrics`: False
- `include_for_metrics`: []
- `eval_do_concat_batches`: True
- `fp16_backend`: auto
- `push_to_hub_model_id`: None
- `push_to_hub_organization`: None
- `mp_parameters`:
- `auto_find_batch_size`: False
- `full_determinism`: False
- `torchdynamo`: None
- `ray_scope`: last
- `ddp_timeout`: 1800
- `torch_compile`: False
- `torch_compile_backend`: None
- `torch_compile_mode`: None
- `dispatch_batches`: None
- `split_batches`: None
- `include_tokens_per_second`: False
- `include_num_input_tokens_seen`: False
- `neftune_noise_alpha`: None
- `optim_target_modules`: None
- `batch_eval_metrics`: False
- `eval_on_start`: False
- `use_liger_kernel`: False
- `eval_use_gather_object`: False
- `average_tokens_across_devices`: False
- `prompts`: None
- `batch_sampler`: batch_sampler
- `multi_dataset_batch_sampler`: proportional

Training logs (rows with epoch and step -1 are evaluations run outside the training loop: the untrained model on stsb-validation and the final model on stsb-test):

| Epoch | Step | Training Loss | Validation Loss | stsb-validation_spearman | stsb-test_spearman |
|---|---|---|---|---|---|
| -1 | -1 | - | - | -0.0150 | - |
| 0.2222 | 20 | 0.6905 | - | - | - |
| 0.4444 | 40 | 0.6548 | - | - | - |
| 0.6667 | 60 | 0.5906 | - | - | - |
| 0.8889 | 80 | 0.5631 | 0.5475 | 0.8589 | - |
| 1.1111 | 100 | 0.5517 | - | - | - |
| 1.3333 | 120 | 0.5473 | - | - | - |
| 1.5556 | 140 | 0.5454 | - | - | - |
| 1.7778 | 160 | 0.5402 | 0.5346 | 0.8760 | - |
| 2.0 | 180 | 0.542 | - | - | - |
| 2.2222 | 200 | 0.5229 | - | - | - |
| 2.4444 | 220 | 0.524 | - | - | - |
| 2.6667 | 240 | 0.5286 | 0.5373 | 0.8744 | - |
| 2.8889 | 260 | 0.5236 | - | - | - |
| 3.1111 | 280 | 0.5269 | - | - | - |
| 3.3333 | 300 | 0.5209 | - | - | - |
| 3.5556 | 320 | 0.5115 | 0.5409 | 0.8754 | - |
| 3.7778 | 340 | 0.5149 | - | - | - |
| 4.0 | 360 | 0.5084 | - | - | - |
| -1 | -1 | - | - | - | 0.8389 |
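The log ends at step 360 after 4 epochs, so each epoch took 90 optimizer steps (consistent with step 20 being logged at epoch 0.2222, i.e. 20/90). Assuming a single device, and given `gradient_accumulation_steps: 1` and `dataloader_drop_last: False`, that bounds the training set at 5,697 to 5,760 pairs. The sketch below also reconstructs the learning-rate schedule implied by `learning_rate: 5e-05`, `warmup_ratio: 0.1` and `lr_scheduler_type: linear`, assuming Hugging Face transformers' rule of rounding warmup up (`ceil(total_steps * warmup_ratio)`) and its linear warmup/decay formula; the variable names and `lr_at` are illustrative.

```python
import math

# Values taken from the hyperparameters and training log on this card.
learning_rate = 5e-05
warmup_ratio = 0.1
num_train_epochs = 4
total_steps = 360  # last optimizer step in the training log
per_device_train_batch_size = 64

steps_per_epoch = total_steps // num_train_epochs  # 90
# Assuming one device and no dropped last batch: 90 batches of 64 hold
# between 89 * 64 + 1 and 90 * 64 training pairs.
min_pairs = (steps_per_epoch - 1) * per_device_train_batch_size + 1
max_pairs = steps_per_epoch * per_device_train_batch_size

# Assumed transformers rule: warmup_steps = ceil(total_steps * warmup_ratio).
warmup_steps = math.ceil(total_steps * warmup_ratio)


def lr_at(step: int) -> float:
    """Linear warmup from 0 to learning_rate, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return learning_rate * step / warmup_steps
    return learning_rate * max(0.0, (total_steps - step) / (total_steps - warmup_steps))


print(f"steps/epoch={steps_per_epoch}, training pairs in [{min_pairs}, {max_pairs}]")
print(f"warmup_steps={warmup_steps}")
for step in (0, 18, warmup_steps, 198, total_steps):
    print(f"step {step:>3}: lr={lr_at(step):.2e}")
```

Under these assumptions the learning rate peaks at step 36, roughly 40% into the first epoch, and halves by step 198.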
Carbon emissions were measured using CodeCarbon.
```bibtex
@inproceedings{reimers-2019-sentence-bert,
    title = "Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks",
    author = "Reimers, Nils and Gurevych, Iryna",
    booktitle = "Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing",
    month = "11",
    year = "2019",
    publisher = "Association for Computational Linguistics",
    url = "https://arxiv.org/abs/1908.10084",
}
```
Base model: `distilbert/distilroberta-base`