# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.

# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
# The below software and/or models in this distribution may have been
# modified by THL A29 Limited ("Tencent Modifications").
# All Tencent Modifications are Copyright (C) THL A29 Limited.

# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the respective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

import inspect
from typing import Any, Dict, List, Optional, Tuple, Union

import os
import torch
import numpy as np
from PIL import Image

import diffusers
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.import_utils import is_xformers_available
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils.torch_utils import randn_tensor
from diffusers.models.attention_processor import (
    Attention,
    AttnProcessor,
    XFormersAttnProcessor,
    AttnProcessor2_0
)
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    EulerAncestralDiscreteScheduler,
    UNet2DConditionModel,
    ImagePipelineOutput
)
import transformers
from transformers import (
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPTokenizer,
    CLIPVisionModelWithProjection,
    CLIPTextModelWithProjection
)

from .utils import to_rgb_image, white_out_background, recenter_img

EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> import torch
        >>> from PIL import Image
        >>> from diffusers import Hunyuan3d_MVD_XL_Pipeline

        >>> pipe = Hunyuan3d_MVD_XL_Pipeline.from_pretrained(
        ...     "Tencent-Hunyuan-3D/MVD-XL", torch_dtype=torch.float16
        ... )
        >>> pipe.to("cuda")

        >>> img = Image.open("demo.png")
        >>> res_img = pipe(img).images[0]
        ```
"""


def scale_latents(latents): return (latents - 0.22) * 0.75
def unscale_latents(latents): return (latents / 0.75) + 0.22
def scale_image(image): return (image - 0.5) / 0.5
def scale_image_2(image): return (image * 0.5) / 0.8
def unscale_image(image): return (image * 0.5) + 0.5
def unscale_image_2(image): return (image * 0.8) / 0.5
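# Note: the pairs above are exact inverses of each other
# (unscale_latents(scale_latents(x)) == x, unscale_image_2(scale_image_2(x)) == x).
# The specific constants (0.22 / 0.75 for latents, 0.5 / 0.8 for images) are assumed to
# match the normalization the multi-view UNet was trained with; only the unscale_*
# helpers are used in this pipeline, to map model outputs back to VAE latent / image ranges.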
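

# ReferenceOnlyAttnProc wraps an existing attention processor and adds a small cache
# protocol for reference-only conditioning: in "w" (write) mode it stores the current
# encoder_hidden_states in ref_dict under the processor's name; in "r" (read) mode it
# pops that entry and concatenates it to the sequence passed to the wrapped processor.
# It is only enabled on self-attention layers ("attn1").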
class ReferenceOnlyAttnProc(torch.nn.Module):
    def __init__(self, chained_proc, enabled=False, name=None):
        super().__init__()
        self.enabled = enabled
        self.chained_proc = chained_proc
        self.name = name

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, mode="w", ref_dict=None):
        encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        if self.enabled:
            if mode == 'w':
                ref_dict[self.name] = encoder_hidden_states
            elif mode == 'r':
                encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
            else:
                raise Exception(f"mode should not be {mode}")
        return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)

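
# RefOnlyNoisedUNet wraps the UNet so that every denoising step runs two passes:
# a "write" pass on a noised copy of the reference (condition) latent, which fills
# ref_dict with the self-attention hidden states, and a "read" pass on the actual
# sample, whose self-attention layers then attend to those cached reference states.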
class RefOnlyNoisedUNet(torch.nn.Module):
    def __init__(self, unet, scheduler) -> None:
        super().__init__()
        self.unet = unet
        self.scheduler = scheduler

        unet_attn_procs = dict()
        for name, _ in unet.attn_processors.items():
            if torch.__version__ >= '2.0':
                default_attn_proc = AttnProcessor2_0()
            elif is_xformers_available():
                default_attn_proc = XFormersAttnProcessor()
            else:
                default_attn_proc = AttnProcessor()
            unet_attn_procs[name] = ReferenceOnlyAttnProc(
                default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
            )
        unet.set_attn_processor(unet_attn_procs)

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.unet, name)

    def forward(
        self,
        sample: torch.FloatTensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        class_labels: Optional[torch.Tensor] = None,
        down_block_res_samples: Optional[Tuple[torch.Tensor]] = None,
        mid_block_res_sample: Optional[Tuple[torch.Tensor]] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        return_dict: bool = True,
        **kwargs
    ):
        dtype = self.unet.dtype

        # add noise to cond_lat at the same timestep level as the current sample
        cond_lat = cross_attention_kwargs['cond_lat']
        noise = torch.randn_like(cond_lat)
        noisy_cond_lat = self.scheduler.add_noise(cond_lat, noise, timestep.reshape(-1))
        noisy_cond_lat = self.scheduler.scale_model_input(noisy_cond_lat, timestep.reshape(-1))

        ref_dict = {}

        # "write" pass: cache the reference self-attention states in ref_dict
        _ = self.unet(
            noisy_cond_lat,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
            added_cond_kwargs=added_cond_kwargs,
            return_dict=return_dict,
            **kwargs
        )

        # "read" pass: denoise the sample while attending to the cached reference states
        res = self.unet(
            sample,
            timestep,
            encoder_hidden_states,
            class_labels=class_labels,
            cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict),
            down_block_additional_residuals=[
                sample.to(dtype=dtype) for sample in down_block_res_samples
            ] if down_block_res_samples is not None else None,
            mid_block_additional_residual=(
                mid_block_res_sample.to(dtype=dtype)
                if mid_block_res_sample is not None else None
            ),
            added_cond_kwargs=added_cond_kwargs,
            return_dict=return_dict,
            **kwargs
        )
        return res

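
# HunYuan3D_MVD_Std_Pipeline: image-conditioned multi-view diffusion. Given a single
# reference image, it denoises a 1536x1024 canvas (a 3x2 grid of 512x512 views) using
# reference-only attention on the condition latent, with CLIP image embeddings blended
# into pre-computed unconditional text embeddings via ramping coefficients, then
# decodes the grid with the VAE.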
class HunYuan3D_MVD_Std_Pipeline(diffusers.DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        feature_extractor_vae: CLIPImageProcessor,
        vision_processor: CLIPImageProcessor,
        vision_encoder: CLIPVisionModelWithProjection,
        vision_encoder_2: CLIPVisionModelWithProjection,
        ramping_coefficients: Optional[list] = None,
        add_watermarker: Optional[bool] = None,
        safety_checker=None,
    ):
        DiffusionPipeline.__init__(self)

        self.register_modules(
            vae=vae, unet=unet, scheduler=scheduler, safety_checker=None, feature_extractor_vae=feature_extractor_vae,
            vision_processor=vision_processor, vision_encoder=vision_encoder, vision_encoder_2=vision_encoder_2,
        )
        self.register_to_config(ramping_coefficients=ramping_coefficients)
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.default_sample_size = self.unet.config.sample_size
        self.watermark = None
        self.prepare_init = False

    def prepare(self):
        assert isinstance(self.unet, UNet2DConditionModel), "unet should be UNet2DConditionModel"
        self.unet = RefOnlyNoisedUNet(self.unet, self.scheduler).eval()
        self.prepare_init = True

    def encode_image(self, image: torch.Tensor, scale_factor: bool = False):
        latent = self.vae.encode(image).latent_dist.sample()
        return (latent * self.vae.config.scaling_factor) if scale_factor else latent

    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
        shape = (
            batch_size,
            num_channels_latents,
            int(height) // self.vae_scale_factor,
            int(width) // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

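    # _get_add_time_ids builds the SDXL-style micro-conditioning vector
    # (original_size + crop_top_left + target_size) and validates its length against
    # the UNet's add_embedding input dimension, mirroring the checks in the
    # diffusers StableDiffusionXL pipelines.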
    def _get_add_time_ids(
        self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
    ):
        add_time_ids = list(original_size + crops_coords_top_left + target_size)

        passed_add_embed_dim = (
            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
        )
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, "
                f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config."
                f" Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        return add_time_ids

    def prepare_extra_step_kwargs(self, generator, eta):
        # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
        # and should be between [0, 1]
        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
        extra_step_kwargs = {}
        if accepts_eta:
            extra_step_kwargs["eta"] = eta

        # check if the scheduler accepts generator
        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
        if accepts_generator:
            extra_step_kwargs["generator"] = generator
        return extra_step_kwargs

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def interrupt(self):
        return self._interrupt

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None

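    # __call__ flow: preprocess the reference image -> encode it to a VAE latent
    # (cond_lat) and to CLIP image embeddings -> blend the image embeddings into the
    # pre-computed unconditional text embeddings via the ramping coefficients ->
    # run classifier-free-guided denoising with reference-only attention on cond_lat
    # -> decode the latent grid to an image (or return the latent directly).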
    @torch.no_grad()
    def __call__(
        self,
        image: Image.Image = None,
        guidance_scale=2.0,
        output_type: Optional[str] = "pil",
        num_inference_steps: int = 50,
        return_dict: bool = True,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        latent: torch.Tensor = None,
        guidance_curve=None,
        **kwargs
    ):
        if not self.prepare_init:
            self.prepare()

        here = dict(device=self.vae.device, dtype=self.vae.dtype)

        batch_size = 1
        num_images_per_prompt = 1
        width, height = 512 * 2, 512 * 3
        target_size = original_size = (height, width)

        self._guidance_scale = guidance_scale
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False

        device = self._execution_device

        # Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            self.vae.dtype,
            device,
            generator,
            latents=latent,
        )

        # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # Prepare added time ids & embeddings
        text_encoder_projection_dim = 1280
        add_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            dtype=self.vae.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        negative_add_time_ids = add_time_ids

        # hw: preprocess
        cond_image = recenter_img(image)
        cond_image = to_rgb_image(cond_image)
        image_vae = self.feature_extractor_vae(images=cond_image, return_tensors="pt").pixel_values.to(**here)
        image_clip = self.vision_processor(images=cond_image, return_tensors="pt").pixel_values.to(**here)

        # hw: get cond_lat from cond_img using vae
        cond_lat = self.encode_image(image_vae, scale_factor=False)
        negative_lat = self.encode_image(torch.zeros_like(image_vae), scale_factor=False)
        cond_lat = torch.cat([negative_lat, cond_lat])

        # hw: get visual global embedding using clip
        global_embeds_1 = self.vision_encoder(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2)
        global_embeds_2 = self.vision_encoder_2(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2)
        global_embeds = torch.concat([global_embeds_1, global_embeds_2], dim=-1)

        ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
        prompt_embeds = self.uc_text_emb.to(**here)
        pooled_prompt_embeds = self.uc_text_emb_2.to(**here)

        prompt_embeds = prompt_embeds + global_embeds * ramp
        add_text_embeds = pooled_prompt_embeds

        if self.do_classifier_free_guidance:
            negative_prompt_embeds = torch.zeros_like(prompt_embeds)
            negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

        # Denoising loop
        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
        timestep_cond = None
        self._num_timesteps = len(timesteps)

        if guidance_curve is None:
            guidance_curve = lambda t: guidance_scale

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    timestep_cond=timestep_cond,
                    cross_attention_kwargs=dict(cond_lat=cond_lat),
                    added_cond_kwargs=added_cond_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                # cur_guidance_scale = self.guidance_scale
                cur_guidance_scale = guidance_curve(t)  # 1.5 + 2.5 * ((t/1000)**2)

                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + cur_guidance_scale * (noise_pred_text - noise_pred_uncond)

                    # cur_guidance_scale_topleft = (cur_guidance_scale - 1.0) * 4 + 1.0
                    # noise_pred_top_left = noise_pred_uncond +
                    #     cur_guidance_scale_topleft * (noise_pred_text - noise_pred_uncond)
                    # _, _, h, w = noise_pred.shape
                    # noise_pred[:, :, :h//3, :w//2] = noise_pred_top_left[:, :, :h//3, :w//2]

                # compute the previous noisy sample x_t -> x_t-1
                latents_dtype = latents.dtype
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        latents = unscale_latents(latents)

        if output_type == "latent":
            image = latents
        else:
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image = unscale_image(unscale_image_2(image)).clamp(0, 1)
            image = [
                Image.fromarray((image[0] * 255 + 0.5).clamp_(0, 255).permute(1, 2, 0).cpu().numpy().astype("uint8")),
                # self.image_processor.postprocess(image, output_type=output_type)[0],
                cond_image.resize((512, 512)),
            ]

        if not return_dict:
            return (image,)
        return ImagePipelineOutput(images=image)

    def save_pretrained(self, save_directory):
        # uc_text_emb.pt and uc_text_emb_2.pt are pre-computed and saved in advance
        super().save_pretrained(save_directory)
        torch.save(self.uc_text_emb, os.path.join(save_directory, "uc_text_emb.pt"))
        torch.save(self.uc_text_emb_2, os.path.join(save_directory, "uc_text_emb_2.pt"))

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # uc_text_emb.pt and uc_text_emb_2.pt are pre-computed and saved in advance
        pipeline = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        pipeline.uc_text_emb = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb.pt"))
        pipeline.uc_text_emb_2 = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb_2.pt"))
        return pipeline
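

# The sketch below is a minimal, illustrative way to run this pipeline locally; it is
# not part of the original file. The checkpoint directory "./weights/mvd_std" and the
# input file "demo.png" are placeholder assumptions: point them at a real local
# snapshot (which must contain uc_text_emb.pt / uc_text_emb_2.pt, see from_pretrained
# above) and a real image. Because of the relative `.utils` import at the top, run it
# as a module, e.g. `python -m <package>.<this_module>`, not as a bare script.
if __name__ == "__main__":
    pipe = HunYuan3D_MVD_Std_Pipeline.from_pretrained(
        "./weights/mvd_std",               # placeholder local checkpoint directory
        torch_dtype=torch.float16,
    )
    pipe.to("cuda")

    input_image = Image.open("demo.png")   # placeholder input image
    out = pipe(input_image, num_inference_steps=50, guidance_scale=2.0)
    out.images[0].save("mvd_views.png")    # 3x2 grid of generated 512x512 views
    out.images[1].save("mvd_cond.png")     # preprocessed condition image (512x512)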