mirror of
https://gitee.com/dify_ai/dify.git
synced 2024-11-29 17:58:19 +08:00
Fix/jina tokenizer cache (#2735)
This commit is contained in:
parent
1809f05904
commit
8fe83750b7
@ -1,20 +1,32 @@
|
||||
from os.path import abspath, dirname, join
|
||||
from threading import Lock
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
class JinaTokenizer:
|
||||
@staticmethod
|
||||
def _get_num_tokens_by_jina_base(text: str) -> int:
|
||||
_tokenizer = None
|
||||
_lock = Lock()
|
||||
|
||||
@classmethod
|
||||
def _get_tokenizer(cls):
|
||||
if cls._tokenizer is None:
|
||||
with cls._lock:
|
||||
if cls._tokenizer is None:
|
||||
base_path = abspath(__file__)
|
||||
gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
|
||||
cls._tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
|
||||
return cls._tokenizer
|
||||
|
||||
@classmethod
|
||||
def _get_num_tokens_by_jina_base(cls, text: str) -> int:
|
||||
"""
|
||||
use jina tokenizer to get num tokens
|
||||
"""
|
||||
base_path = abspath(__file__)
|
||||
gpt2_tokenizer_path = join(dirname(base_path), 'tokenizer')
|
||||
tokenizer = AutoTokenizer.from_pretrained(gpt2_tokenizer_path)
|
||||
tokenizer = cls._get_tokenizer()
|
||||
tokens = tokenizer.encode(text)
|
||||
return len(tokens)
|
||||
|
||||
@staticmethod
|
||||
def get_num_tokens(text: str) -> int:
|
||||
return JinaTokenizer._get_num_tokens_by_jina_base(text)
|
||||
@classmethod
|
||||
def get_num_tokens(cls, text: str) -> int:
|
||||
return cls._get_num_tokens_by_jina_base(text)
|
Loading…
Reference in New Issue
Block a user