diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py index c388341d5..50f8c73ed 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/jina_tokenizer.py @@ -1,20 +1,32 @@ from os.path import abspath, dirname, join +from threading import Lock from transformers import AutoTokenizer class JinaTokenizer: - @staticmethod - def _get_num_tokens_by_jina_base(text: str) -> int: + _tokenizer = None + _lock = Lock() + + @classmethod + def _get_tokenizer(cls): + if cls._tokenizer is None: + with cls._lock: + if cls._tokenizer is None: + base_path = abspath(__file__) + gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer') + cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path) + return cls._tokenizer + + @classmethod + def _get_num_tokens_by_jina_base(cls, text: str) -> int: """ use jina tokenizer to get num tokens """ - base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer') - tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path) + tokenizer = cls._get_tokenizer() tokens = tokenizer.encode(text) return len(tokens) - @staticmethod - def get_num_tokens(text: str) -> int: - return JinaTokenizer._get_num_tokens_by_jina_base(text) \ No newline at end of file + @classmethod + def get_num_tokens(cls, text: str) -> int: + return cls._get_num_tokens_by_jina_base(text) \ No newline at end of file