From 8fe83750b7a3dd07b6ebd952676220a8a03b06f8 Mon Sep 17 00:00:00 2001
From: Yeuoly <45712896+Yeuoly@users.noreply.github.com>
Date: Thu, 7 Mar 2024 16:32:37 +0800
Subject: [PATCH] Fix/jina tokenizer cache (#2735)

---
 .../jina/text_embedding/jina_tokenizer.py     | 28 +++++++++++++------
 1 file changed, 20 insertions(+), 8 deletions(-)

diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py
index c388341d5..50f8c73ed 100644
--- a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py
+++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py
@@ -1,20 +1,32 @@
 from os.path import abspath, dirname, join
+from threading import Lock
 
 from transformers import AutoTokenizer
 
 
 class JinaTokenizer:
-    @staticmethod
-    def _get_num_tokens_by_jina_base(text: str) -> int:
+    _tokenizer = None
+    _lock = Lock()
+
+    @classmethod
+    def _get_tokenizer(cls):
+        if cls._tokenizer is None:
+            with cls._lock:
+                if cls._tokenizer is None:
+                    base_path = abspath(__file__)
+                    gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
+                    cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
+        return cls._tokenizer
+
+    @classmethod
+    def _get_num_tokens_by_jina_base(cls, text: str) -> int:
         """
             use jina tokenizer to get num tokens
         """
-        base_path = abspath(__file__)
-        gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
-        tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
+        tokenizer = cls._get_tokenizer()
         tokens = tokenizer.encode(text)
         return len(tokens)
 
-    @staticmethod
-    def get_num_tokens(text: str) -> int:
-        return JinaTokenizer._get_num_tokens_by_jina_base(text)
\ No newline at end of file
+    @classmethod
+    def get_num_tokens(cls, text: str) -> int:
+        return cls._get_num_tokens_by_jina_base(text)
\ No newline at end of file