Fix Issue: switch LLM of SageMaker endpoint doesn't take effect (#8737)

Co-authored-by: Yuanbo Li <ybalbert@amazon.com>
This commit is contained in:
ybalbert001 2024-09-25 09:12:35 +08:00 committed by GitHub
parent 91f70d0bd9
commit 68c7e68a8a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -84,8 +84,9 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
Model class for Cohere large language model.
"""
sagemaker_client: Any = None
sagemaker_session: Any = None
predictor: Any = None
sagemaker_endpoint: str = None
def _handle_chat_generate_response(
self,
@ -211,7 +212,7 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
:param user: unique user id
:return: full response or stream response chunk generator result
"""
if not self.sagemaker_client:
if not self.sagemaker_session:
access_key = credentials.get("aws_access_key_id")
secret_key = credentials.get("aws_secret_access_key")
aws_region = credentials.get("aws_region")
@ -226,11 +227,14 @@ class SageMakerLargeLanguageModel(LargeLanguageModel):
else:
boto_session = boto3.Session()
self.sagemaker_client = boto_session.client("sagemaker")
sagemaker_session = Session(boto_session=boto_session, sagemaker_client=self.sagemaker_client)
sagemaker_client = boto_session.client("sagemaker")
self.sagemaker_session = Session(boto_session=boto_session, sagemaker_client=sagemaker_client)
if self.sagemaker_endpoint != credentials.get("sagemaker_endpoint"):
self.sagemaker_endpoint = credentials.get("sagemaker_endpoint")
self.predictor = Predictor(
endpoint_name=credentials.get("sagemaker_endpoint"),
sagemaker_session=sagemaker_session,
endpoint_name=self.sagemaker_endpoint,
sagemaker_session=self.sagemaker_session,
serializer=serializers.JSONSerializer(),
)