feat: add support for bedrock Mistral AI model (#3676)

Co-authored-by: Chenhe Gu <guchenhe@gmail.com>
This commit is contained in:
longzhihun 2024-04-22 17:24:02 +08:00 committed by GitHub
parent 4969783383
commit 28236147ee
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 129 additions and 0 deletions

View File

@ -10,3 +10,6 @@
- cohere.command-text-v14
- meta.llama2-13b-chat-v1
- meta.llama2-70b-chat-v1
- mistral.mistral-large-2402-v1:0
- mistral.mixtral-8x7b-instruct-v0:1
- mistral.mistral-7b-instruct-v0:2

View File

@ -449,6 +449,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
human_prompt_prefix = "\n[INST]"
human_prompt_postfix = "[\\INST]\n"
ai_prompt = ""
elif model_prefix == "mistral":
human_prompt_prefix = "<s>[INST]"
human_prompt_postfix = "[\\INST]\n"
ai_prompt = "\n\nAssistant:"
elif model_prefix == "amazon":
human_prompt_prefix = "\n\nUser:"
@ -519,6 +524,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
payload["frequencyPenalty"] = {model_parameters.get("frequencyPenalty")}
if model_parameters.get("countPenalty"):
payload["countPenalty"] = {model_parameters.get("countPenalty")}
elif model_prefix == "mistral":
payload["temperature"] = model_parameters.get("temperature")
payload["top_p"] = model_parameters.get("top_p")
payload["max_tokens"] = model_parameters.get("max_tokens")
payload["prompt"] = self._convert_messages_to_prompt(prompt_messages, model_prefix)
payload["stop"] = stop[:10] if stop else []
elif model_prefix == "anthropic":
payload = { **model_parameters }
@ -648,6 +660,11 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
output = response_body.get("generation").strip('\n')
prompt_tokens = response_body.get("prompt_token_count")
completion_tokens = response_body.get("generation_token_count")
elif model_prefix == "mistral":
output = response_body.get("outputs")[0].get("text")
prompt_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-input-token-count')
completion_tokens = response.get('ResponseMetadata').get('HTTPHeaders').get('x-amzn-bedrock-output-token-count')
else:
raise ValueError(f"Got unknown model prefix {model_prefix} when handling block response")
@ -731,6 +748,10 @@ class BedrockLargeLanguageModel(LargeLanguageModel):
content_delta = payload.get("text")
finish_reason = payload.get("finish_reason")
elif model_prefix == "mistral":
content_delta = payload.get('outputs')[0].get("text")
finish_reason = payload.get('outputs')[0].get("stop_reason")
elif model_prefix == "meta":
content_delta = payload.get("generation").strip('\n')
finish_reason = payload.get("stop_reason")

View File

@ -0,0 +1,39 @@
model: mistral.mistral-7b-instruct-v0:2
label:
en_US: Mistral 7B Instruct
model_type: llm
model_properties:
mode: completion
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
required: false
default: 0.5
- name: top_p
use_template: top_p
required: false
default: 0.9
- name: top_k
use_template: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 50
max: 200
- name: max_tokens
use_template: max_tokens
required: true
default: 512
min: 1
max: 8192
pricing:
input: '0.00015'
output: '0.0002'
unit: '0.00001'
currency: USD

View File

@ -0,0 +1,27 @@
model: mistral.mistral-large-2402-v1:0
label:
en_US: Mistral Large
model_type: llm
model_properties:
mode: completion
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
required: false
default: 0.7
- name: top_p
use_template: top_p
required: false
default: 1
- name: max_tokens
use_template: max_tokens
required: true
default: 512
min: 1
max: 4096
pricing:
input: '0.008'
output: '0.024'
unit: '0.001'
currency: USD

View File

@ -0,0 +1,39 @@
model: mistral.mixtral-8x7b-instruct-v0:1
label:
en_US: Mixtral 8X7B Instruct
model_type: llm
model_properties:
mode: completion
context_size: 32000
parameter_rules:
- name: temperature
use_template: temperature
required: false
default: 0.5
- name: top_p
use_template: top_p
required: false
default: 0.9
- name: top_k
use_template: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
default: 50
max: 200
- name: max_tokens
use_template: max_tokens
required: true
default: 512
min: 1
max: 8192
pricing:
input: '0.00045'
output: '0.0007'
unit: '0.00001'
currency: USD