init
95
LICENSE.txt
Normal file
@ -0,0 +1,95 @@
|
||||
TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
Tencent Hunyuan3D Release Date: 2024.11.5
|
||||
By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
|
||||
|
||||
1.DEFINITIONS.
|
||||
|
||||
a.“Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
|
||||
b.“Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
|
||||
c.“Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
|
||||
d.“Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
|
||||
e.“Non-Commercial” shall mean a use of the Tencent Hunyuan Works for academic, research and education purposes only.
|
||||
f.“Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
|
||||
g.“Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
|
||||
h.“Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
|
||||
i.“Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
|
||||
j.“Tencent,” “We” or “Us” shall mean THL A29 Limited.
|
||||
k.“Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hunyuan 3D released at [https://github.com/Tencent/Hunyuan3D-1/tree/main].
|
||||
l.“Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
|
||||
m.“Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
|
||||
n.“including” shall mean including but not limited to.
|
||||
|
||||
2.GRANT OF RIGHTS.
|
||||
|
||||
We grant You a non-exclusive, non-transferable, non-commercial, royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
|
||||
|
||||
3.DISTRIBUTION.
|
||||
|
||||
You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works provided that You meet all of the following conditions:
|
||||
a.You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
|
||||
b.You must cause any modified files to carry prominent notices stating that You changed the files;
|
||||
c.You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
|
||||
d.All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Non-Commercial License Agreement, Copyright © 2024 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
|
||||
You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement. If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
|
||||
|
||||
4.ADDITIONAL NON-COMMERCIAL TERMS
|
||||
|
||||
If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
|
||||
|
||||
5.RULES OF USE.
|
||||
|
||||
a.Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
|
||||
b.You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other large language model (other than Tencent Hunyuan or Model Derivatives thereof).
|
||||
|
||||
6.INTELLECTUAL PROPERTY.
|
||||
|
||||
a.Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
|
||||
b.No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
|
||||
c.If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
|
||||
d.Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
|
||||
|
||||
7.DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
|
||||
|
||||
a.We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
|
||||
b.UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
|
||||
c.TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
||||
|
||||
8.SURVIVAL AND TERMINATION.
|
||||
|
||||
a.The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
|
||||
b.We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
|
||||
|
||||
9.GOVERNING LAW AND JURISDICTION.
|
||||
|
||||
a.This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
|
||||
b.Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
|
||||
|
||||
|
||||
EXHIBIT A
|
||||
|
||||
ACCEPTABLE USE POLICY
|
||||
|
||||
Tencent reserves the right to update this Acceptable Use Policy from time to time.
|
||||
Last modified: 2024.11.5
|
||||
|
||||
Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
|
||||
1.In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
|
||||
2.To harm Yourself or others;
|
||||
3.To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
|
||||
4.To override or circumvent the safety guardrails and safeguards We have put in place;
|
||||
5.For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
||||
6.To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
|
||||
7.To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
|
||||
8.To intentionally defame, disparage or otherwise harass others;
|
||||
9.To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
|
||||
10.To generate or disseminate personal identifiable information with the purpose of harming others;
|
||||
11.To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
|
||||
12.To impersonate another individual without consent, authorization, or legal right;
|
||||
13.To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
|
||||
14.In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
|
||||
15.To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
|
||||
16.For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
|
||||
17.To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
||||
18.For military purposes;
|
||||
19.To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.
|
||||
THIS ACCEPTABLE USE POLICY INCORPORATES BY REFERENCE THE USER-BASED RESTRICTIONS OUTLINED IN THE CREATIVEML OPEN RAIL++-M LICENSE. https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/LICENSE.md
|
193
Notice
Normal file
@ -0,0 +1,193 @@
|
||||
Usage and Legal Notices:
|
||||
|
||||
Tencent is pleased to support the open source community by making Hunyuan 3D available.
|
||||
|
||||
Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. The below software and/or models in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||
|
||||
Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT except for the third-party components listed below. Hunyuan 3D does not impose any additional limitations beyond what is outlined in the repsective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
|
||||
|
||||
For avoidance of doubts, Hunyuan 3D means the large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
|
||||
Other dependencies and licenses:
|
||||
|
||||
|
||||
Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
||||
The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||
--------------------------------------------------------------------
|
||||
1. instantmesh
|
||||
Copyright (c) instantmesh original author and authors
|
||||
Please note this software has been modified by Tencent in this distribution.
|
||||
|
||||
|
||||
Terms of the Apache License Version 2.0:
|
||||
--------------------------------------------------------------------
|
||||
Apache License
|
||||
|
||||
Version 2.0, January 2004
|
||||
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
||||
|
||||
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
||||
|
||||
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
||||
|
||||
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
||||
|
||||
If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
|
||||
For the license of other third party dependencies, please refer to the following URL:
|
||||
https://github.com/TencentARC/InstantMesh/blob/main/LICENSE
|
||||
https://github.com/TencentARC/InstantMesh/tree/main?tab=readme-ov-file#-acknowledgements
|
||||
|
||||
|
||||
Open Source Model Licensed under the MIT and CreativeML Open RAIL++-M License:
|
||||
The below Model in this distribution may have been modified by Tencent.
|
||||
--------------------------------------------------------------------
|
||||
1. Stable Diffusion
|
||||
Copyright (c) 2022 Stability AI and contributors
|
||||
|
||||
|
||||
Terms of the MIT and CreativeML Open RAIL++-M License:
|
||||
--------------------------------------------------------------------
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
||||
CreativeML Open RAIL++-M License
|
||||
dated November 24, 2022
|
||||
|
||||
Section I: PREAMBLE
|
||||
|
||||
Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation.
|
||||
|
||||
Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations.
|
||||
|
||||
In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation.
|
||||
|
||||
Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI.
|
||||
|
||||
This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model.
|
||||
|
||||
NOW THEREFORE, You and Licensor agree as follows:
|
||||
|
||||
1. Definitions
|
||||
|
||||
- "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
|
||||
- "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
|
||||
- "Output" means the results of operating a Model as embodied in informational content resulting therefrom.
|
||||
- "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
|
||||
- "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
|
||||
- "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
|
||||
- "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
|
||||
- "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
|
||||
- "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator.
|
||||
- "Third Parties" means individuals or legal entities that are not under common control with Licensor or You.
|
||||
- "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
||||
- "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model.
|
||||
|
||||
Section II: INTELLECTUAL PROPERTY RIGHTS
|
||||
|
||||
Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model.
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed.
|
||||
|
||||
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
|
||||
|
||||
4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
|
||||
Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
|
||||
You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
|
||||
You must cause any modified files to carry prominent notices stating that You changed the files;
|
||||
You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
|
||||
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
|
||||
5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).
|
||||
6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License.
|
||||
|
||||
Section IV: OTHER PROVISIONS
|
||||
|
||||
7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License.
|
||||
8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors.
|
||||
9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License.
|
||||
10. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
||||
11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
||||
12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
Attachment A
|
||||
|
||||
Use Restrictions
|
||||
|
||||
You agree not to use the Model or Derivatives of the Model:
|
||||
|
||||
- In any way that violates any applicable national, federal, state, local or international law or regulation;
|
||||
- For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
|
||||
- To generate or disseminate verifiably false information and/or content with the purpose of harming others;
|
||||
- To generate or disseminate personal identifiable information that can be used to harm an individual;
|
||||
- To defame, disparage or otherwise harass others;
|
||||
- For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation;
|
||||
- For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics;
|
||||
- To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
|
||||
- For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories;
|
||||
- To provide medical advice and medical results interpretation;
|
||||
- To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use).
|
189
README.md
Normal file
@ -0,0 +1,189 @@
|
||||
<!-- ## **Hunyuan3D-1.0** -->
|
||||
|
||||
<p align="center">
|
||||
<img src="./assets/logo.png" height=200>
|
||||
</p>
|
||||
|
||||
# Tencent Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation
|
||||
|
||||
[\[Code\]](https://github.com/tencent/Hunyuan3D-1)
|
||||
[\[Huggingface\]](https://huggingface.co/tencent/Hunyuan3D-1)
|
||||
[\[Report\]](https://3d.hunyuan.tencent.com/hunyuan3d.pdf)
|
||||
|
||||
|
||||
## 🔥🔥🔥 News!!
|
||||
|
||||
* Nov 5, 2024: 💬 We support demo running image_to_3d generation now. Please check the [script](#using-gradio) below.
|
||||
* Nov 5, 2024: 💬 We support demo running text_to_3d generation now. Please check the [script](#using-gradio) below.
|
||||
|
||||
|
||||
## 📑 Open-source Plan
|
||||
|
||||
- [x] Inference
|
||||
- [x] Checkpoints
|
||||
- [ ] Baking related
|
||||
- [ ] Training
|
||||
- [ ] ComfyUI
|
||||
- [ ] Distillation Version
|
||||
- [ ] TensorRT Version
|
||||
|
||||
|
||||
|
||||
## **Abstract**
|
||||
<p align="center">
|
||||
<img src="./assets/teaser.png" height=450>
|
||||
</p>
|
||||
|
||||
While 3D generative models have greatly improved artists' workflows, the existing diffusion models for 3D generation suffer from slow generation and poor generalization. To address this issue, we propose a two-stage approach named Hunyuan3D-1.0 including a lite version and a standard version, that both support text- and image-conditioned generation.
|
||||
|
||||
In the first stage, we employ a multi-view diffusion model that efficiently generates multi-view RGB in approximately 4 seconds. These multi-view images capture rich details of the 3D asset from different viewpoints, relaxing the tasks from single-view to multi-view reconstruction. In the second stage, we introduce a feed-forward reconstruction model that rapidly and faithfully reconstructs the 3D asset given the generated multi-view images in approximately 7 seconds. The reconstruction network learns to handle noises and in-consistency introduced by the multi-view diffusion and leverages the available information from the condition image to efficiently recover the 3D structure.
|
||||
|
||||
Our framework involves the text-to-image model, i.e., Hunyuan-DiT, making it a unified framework to support both text- and image-conditioned 3D generation. Our standard version has 3x more parameters than our lite and other existing model. Our Hunyuan3D-1.0 achieves an impressive balance between speed and quality, significantly reducing generation time while maintaining the quality and diversity of the produced assets.
|
||||
|
||||
|
||||
## 🎉 **Hunyuan3D-1 Architecture**
|
||||
|
||||
<p align="center">
|
||||
<img src="./assets/overview_3.png" height=400>
|
||||
</p>
|
||||
|
||||
|
||||
## 📈 Comparisons
|
||||
|
||||
We have evaluated Hunyuan3D-1.0 with other open-source 3d-generation methods, our Hunyuan3D-1.0 received the highest user preference across 5 metrics. Details in the picture on the lower left.
|
||||
|
||||
The lite model takes around 10 seconds to produce a 3D mesh from a single image on an NVIDIA A100 GPU, while the standard model takes roughly 25 seconds. The plot laid out in the lower right demonstrates that Hunyuan3D-1.0 achieves an optimal balance between quality and efficiency.
|
||||
|
||||
<p align="center">
|
||||
<img src="./assets/radar.png" height=300>
|
||||
<img src="./assets/runtime.png" height=300>
|
||||
</p>
|
||||
|
||||
## Get Started
|
||||
|
||||
#### Begin by cloning the repository:
|
||||
|
||||
```shell
|
||||
git clone https://github.com/tencent/Hunyuan3D-1
|
||||
cd Hunyuan3D-1
|
||||
```
|
||||
|
||||
#### Installation Guide for Linux
|
||||
|
||||
We provide an env_install.sh script file for setting up environment.
|
||||
|
||||
We recommend python3.9 and CUDA11.7+
|
||||
```
|
||||
conda create -n hunyuan3d-1 python=3.9
|
||||
conda activate hunyuan3d-1
|
||||
bash env_install.sh
|
||||
```
|
||||
|
||||
#### Download Pretrained Models
|
||||
|
||||
The models are available at [https://huggingface.co/spaces/tencent/Hunyuan3D-1](https://huggingface.co/spaces/tencent/Hunyuan3D-1):
|
||||
|
||||
+ `Hunyuan3D-1/lite`, lite model for multi-view generation.
|
||||
+ `Hunyuan3D-1/std`, standard model for multi-view generation.
|
||||
+ `Hunyuan3D-1/svrm`, sparse-view reconstruction model.
|
||||
|
||||
|
||||
To download the model, first install the huggingface-cli. (Detailed instructions are available [here](https://huggingface.co/docs/huggingface_hub/guides/cli).)
|
||||
|
||||
```shell
|
||||
python3 -m pip install "huggingface_hub[cli]"
|
||||
```
|
||||
|
||||
Then download the model using the following commands:
|
||||
|
||||
```shell
|
||||
mkdir weights
|
||||
huggingface-cli download tencent/Hunyuan3D-1 --local-dir ./weights
|
||||
|
||||
mkdir weights/hunyuanDiT
|
||||
huggingface-cli download Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled --local-dir ./weights/hunyuanDiT
|
||||
```
|
||||
|
||||
#### Inference
|
||||
For text to 3d generation, we supports bilingual Chinese and English, you can use the following command to inference.
|
||||
```python
|
||||
python3 main.py \
|
||||
--text_prompt "a lovely rabbit" \
|
||||
--save_folder ./outputs/test/ \
|
||||
--max_faces_num 90000 \
|
||||
--do_texture_mapping \
|
||||
--do_render
|
||||
```
|
||||
|
||||
For image to 3d generation, you can use the following command to inference.
|
||||
```python
|
||||
python3 main.py \
|
||||
--image_prompt "/path/to/your/image" \
|
||||
--save_folder ./outputs/test/ \
|
||||
--max_faces_num 90000 \
|
||||
--do_texture_mapping \
|
||||
--do_render
|
||||
```
|
||||
We list some more useful configurations for easy usage:
|
||||
|
||||
| Argument | Default | Description |
|
||||
|:------------------:|:---------:|:---------------------------------------------------:|
|
||||
|`--text_prompt` | None |The text prompt for 3D generation |
|
||||
|`--image_prompt` | None |The image prompt for 3D generation |
|
||||
|`--t2i_seed` | 0 |The random seed for generating images |
|
||||
|`--t2i_steps` | 25 |The number of steps for sampling of text to image |
|
||||
|`--gen_seed` | 0 |The random seed for generating 3d generation |
|
||||
|`--gen_steps` | 50 |The number of steps for sampling of 3d generation |
|
||||
|`--max_faces_numm` | 90000 |The limit number of faces of 3d mesh |
|
||||
|`--save_memory` | False |text2image will move to cpu automatically|
|
||||
|`--do_texture_mapping` | False |Change vertex shadding to texture shading |
|
||||
|`--do_render` | False |render gif |
|
||||
|
||||
|
||||
We have also prepared scripts with different configurations for reference
|
||||
```bash
|
||||
bash scripts/text_to_3d_demo.sh
|
||||
bash scripts/text_to_3d_fast_demo.sh
|
||||
bash scripts/image_to_3d_demo.sh
|
||||
bash scripts/image_to_3d_fast_demo.sh
|
||||
```
|
||||
|
||||
This example requires ~40GB VRAM to run.
|
||||
|
||||
#### Using Gradio
|
||||
|
||||
We have prepared two versions of multi-view generation, std and lite.
|
||||
|
||||
For better results, the std version of the running script is as follows
|
||||
```shell
|
||||
python3 app.py
|
||||
```
|
||||
|
||||
For faster speed, you can use the lite version by adding the --use_lite parameter.
|
||||
|
||||
```shell
|
||||
python3 app.py --use_lite
|
||||
```
|
||||
|
||||
Then the demo can be accessed through http://0.0.0.0:8080. It should be noted that the 0.0.0.0 here needs to be X.X.X.X with your server IP.
|
||||
|
||||
## Camera Parameters
|
||||
|
||||
Output views are a fixed set of camera poses:
|
||||
|
||||
+ Azimuth (relative to input view): `+0, +60, +120, +180, +240, +300`.
|
||||
|
||||
|
||||
<!-- ## Citation
|
||||
|
||||
If you found this repository helpful, please cite our report:
|
||||
```bibtex
|
||||
@misc{xxx_todo,
|
||||
title={Hunyuan3D-1.0: First Unified Framework for Text-to-3D and Image-to-3D Generation},
|
||||
author={},
|
||||
year={2024},
|
||||
eprint={},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV}
|
||||
}
|
||||
``` -->
|
302
app.py
Normal file
@ -0,0 +1,302 @@
|
||||
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
||||
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||
|
||||
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||
# The below software and/or models in this distribution may have been
|
||||
# modified by THL A29 Limited ("Tencent Modifications").
|
||||
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
import os
|
||||
import warnings
|
||||
warnings.simplefilter('ignore', category=UserWarning)
|
||||
warnings.simplefilter('ignore', category=FutureWarning)
|
||||
warnings.simplefilter('ignore', category=DeprecationWarning)
|
||||
|
||||
import gradio as gr
|
||||
from glob import glob
|
||||
import shutil
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from einops import rearrange
|
||||
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--use_lite", default=False, action="store_true")
|
||||
parser.add_argument("--mv23d_cfg_path", default="./svrm/configs/svrm.yaml", type=str)
|
||||
parser.add_argument("--mv23d_ckt_path", default="weights/svrm/svrm.safetensors", type=str)
|
||||
parser.add_argument("--text2image_path", default="weights/hunyuanDiT", type=str)
|
||||
parser.add_argument("--save_memory", default=False, action="store_true")
|
||||
parser.add_argument("--device", default="cuda:0", type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
################################################################
|
||||
|
||||
CONST_PORT = 8080
|
||||
CONST_MAX_QUEUE = 1
|
||||
CONST_SERVER = '0.0.0.0'
|
||||
|
||||
CONST_HEADER = '''
|
||||
<h2><b>Official 🤗 Gradio Demo</b></h2><h2><a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'><b>Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D
|
||||
Generationr</b></a></h2>
|
||||
Code: <a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'>GitHub</a>. Techenical report: <a href='https://arxiv.org/abs/placeholder' target='_blank'>ArXiv</a>.
|
||||
|
||||
❗️❗️❗️**Important Notes:**
|
||||
- Our demo can export a .obj mesh with vertex colors or a .glb mesh by default.
|
||||
- If you check "texture mapping", it will export a .obj mesh with a texture map or a .glb mesh.
|
||||
- If you check "render Gif", it will export gif image rendering .glb file.
|
||||
- If the result is unsatisfying, please try a different **seed value** (Default: 0).
|
||||
'''
|
||||
|
||||
CONST_CITATION = r"""
|
||||
If HunYuan3D-1 is helpful, please help to ⭐ the <a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'>Github Repo</a>. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/tencent/Hunyuan3D-1?style=social)](https://github.com/tencent/Hunyuan3D-1)
|
||||
---
|
||||
📝 **Citation**
|
||||
If you find our work useful for your research or applications, please cite using this bibtex:
|
||||
```bibtex
|
||||
@misc{xxx,
|
||||
title={Hunyuan3D-1.0: First Unified Framework for Text-to-3D and Image-to-3D Generation},
|
||||
author={},
|
||||
year={2024},
|
||||
eprint={},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.CV}
|
||||
}
|
||||
```
|
||||
"""
|
||||
|
||||
################################################################
|
||||
|
||||
def get_example_img_list():
|
||||
print('Loading example img list ...')
|
||||
return sorted(glob('./demos/example_*.png'))
|
||||
|
||||
def get_example_txt_list():
|
||||
print('Loading example txt list ...')
|
||||
txt_list = list()
|
||||
for line in open('./demos/example_list.txt'):
|
||||
txt_list.append(line.strip())
|
||||
return txt_list
|
||||
|
||||
example_is = get_example_img_list()
|
||||
example_ts = get_example_txt_list()
|
||||
################################################################
|
||||
|
||||
from infer import seed_everything, save_gif
|
||||
from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
|
||||
|
||||
|
||||
worker_xbg = Removebg()
|
||||
print(f"loading {args.text2image_path}")
|
||||
worker_t2i = Text2Image(
|
||||
pretrain = args.text2image_path,
|
||||
device = args.device,
|
||||
save_memory = args.save_memory
|
||||
)
|
||||
worker_i2v = Image2Views(
|
||||
use_lite = args.use_lite,
|
||||
device = args.device
|
||||
)
|
||||
worker_v23 = Views2Mesh(
|
||||
args.mv23d_cfg_path,
|
||||
args.mv23d_ckt_path,
|
||||
use_lite = args.use_lite,
|
||||
device = args.device
|
||||
)
|
||||
worker_gif = GifRenderer(args.device)
|
||||
|
||||
def stage_0_t2i(text, image, seed, step):
|
||||
# prepare save_folder
|
||||
os.makedirs('./outputs/app_output', exist_ok=True)
|
||||
exists = set(int(_) for _ in os.listdir('./outputs/app_output') if not _.startswith("."))
|
||||
if len(exists) == 30: shutil.rmtree(f"./outputs/app_output/0");cur_id = 0
|
||||
else: cur_id = min(set(range(30)) - exists)
|
||||
if os.path.exists(f"./outputs/app_output/{(cur_id + 1) % 30}"):
|
||||
shutil.rmtree(f"./outputs/app_output/{(cur_id + 1) % 30}")
|
||||
save_folder = f'./outputs/app_output/{cur_id}'
|
||||
os.makedirs(save_folder, exist_ok=True)
|
||||
|
||||
dst = save_folder + '/img.png'
|
||||
|
||||
if not text:
|
||||
if image is None:
|
||||
return dst, save_folder
|
||||
raise gr.Error("Upload image or provide text ...")
|
||||
image.save(dst)
|
||||
return dst, save_folder
|
||||
|
||||
image = worker_t2i(text, seed, step)
|
||||
image.save(dst)
|
||||
dst = worker_xbg(image, save_folder)
|
||||
return dst, save_folder
|
||||
|
||||
def stage_1_xbg(image, save_folder):
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image)
|
||||
dst = save_folder + '/img_nobg.png'
|
||||
rgba = worker_xbg(image)
|
||||
rgba.save(dst)
|
||||
return dst
|
||||
|
||||
def stage_2_i2v(image, seed, step, save_folder):
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image)
|
||||
gif_dst = save_folder + '/views.gif'
|
||||
res_img, pils = worker_i2v(image, seed, step)
|
||||
save_gif(pils, gif_dst)
|
||||
views_img, cond_img = res_img[0], res_img[1]
|
||||
img_array = np.asarray(views_img, dtype=np.uint8)
|
||||
show_img = rearrange(img_array, '(n h) (m w) c -> (n m) h w c', n=3, m=2)
|
||||
show_img = show_img[worker_i2v.order, ...]
|
||||
show_img = rearrange(show_img, '(n m) h w c -> (n h) (m w) c', n=2, m=3)
|
||||
show_img = Image.fromarray(show_img)
|
||||
return views_img, cond_img, show_img
|
||||
|
||||
def stage_3_v23(
|
||||
views_pil,
|
||||
cond_pil,
|
||||
seed,
|
||||
save_folder,
|
||||
target_face_count = 30000,
|
||||
do_texture_mapping = True,
|
||||
do_render =True
|
||||
):
|
||||
do_texture_mapping = do_texture_mapping or do_render
|
||||
obj_dst = save_folder + '/mesh_with_colors.obj'
|
||||
glb_dst = save_folder + '/mesh.glb'
|
||||
worker_v23(
|
||||
views_pil,
|
||||
cond_pil,
|
||||
seed = seed,
|
||||
save_folder = save_folder,
|
||||
target_face_count = target_face_count,
|
||||
do_texture_mapping = do_texture_mapping
|
||||
)
|
||||
return obj_dst, glb_dst
|
||||
|
||||
def stage_4_gif(obj_dst, save_folder, do_render_gif=True):
|
||||
if not do_render_gif: return None
|
||||
gif_dst = save_folder + '/output.gif'
|
||||
worker_gif(
|
||||
save_folder + '/mesh.obj',
|
||||
gif_dst_path = gif_dst
|
||||
)
|
||||
return gif_dst
|
||||
|
||||
#===============================================================
|
||||
with gr.Blocks() as demo:
|
||||
gr.Markdown(CONST_HEADER)
|
||||
with gr.Row(variant="panel"):
|
||||
with gr.Column(scale=2):
|
||||
with gr.Tab("Text to 3D"):
|
||||
with gr.Column():
|
||||
text = gr.TextArea('一只黑白相间的熊猫在白色背景上居中坐着,呈现出卡通风格和可爱氛围。', lines=1, max_lines=10, label='Input text')
|
||||
with gr.Row():
|
||||
textgen_seed = gr.Number(value=0, label="T2I seed", precision=0)
|
||||
textgen_step = gr.Number(value=25, label="T2I step", precision=0)
|
||||
textgen_SEED = gr.Number(value=0, label="Gen seed", precision=0)
|
||||
textgen_STEP = gr.Number(value=50, label="Gen step", precision=0)
|
||||
textgen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0)
|
||||
|
||||
with gr.Row():
|
||||
textgen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False, interactive=True)
|
||||
textgen_do_render_gif = gr.Checkbox(label="Render gif", value=False, interactive=True)
|
||||
textgen_submit = gr.Button("Generate", variant="primary")
|
||||
|
||||
with gr.Row():
|
||||
gr.Examples(examples=example_ts, inputs=[text], label="Txt examples", examples_per_page=10)
|
||||
|
||||
with gr.Tab("Image to 3D"):
|
||||
with gr.Column():
|
||||
input_image = gr.Image(label="Input image",
|
||||
width=256, height=256, type="pil",
|
||||
image_mode="RGBA", sources="upload",
|
||||
interactive=True)
|
||||
with gr.Row():
|
||||
imggen_SEED = gr.Number(value=0, label="Gen seed", precision=0)
|
||||
imggen_STEP = gr.Number(value=50, label="Gen step", precision=0)
|
||||
imggen_max_faces = gr.Number(value=90000, label="max number of faces", precision=0)
|
||||
|
||||
with gr.Row():
|
||||
imggen_do_texture_mapping = gr.Checkbox(label="texture mapping", value=False, interactive=True)
|
||||
imggen_do_render_gif = gr.Checkbox(label="Render gif", value=False, interactive=True)
|
||||
imggen_submit = gr.Button("Generate", variant="primary")
|
||||
with gr.Row():
|
||||
gr.Examples(examples=example_is, inputs=[input_image], label="Img examples", examples_per_page=10)
|
||||
|
||||
with gr.Column(scale=3):
|
||||
with gr.Tab("rembg image"):
|
||||
rem_bg_image = gr.Image(label="No backgraound image",
|
||||
width=256, height=256, type="pil",
|
||||
image_mode="RGBA", interactive=False)
|
||||
|
||||
with gr.Tab("Multi views"):
|
||||
result_image = gr.Image(label="Multi views", type="pil", interactive=False)
|
||||
with gr.Tab("Obj"):
|
||||
result_3dobj = gr.Model3D(label="Output obj", interactive=False)
|
||||
with gr.Tab("Glb"):
|
||||
result_3dglb = gr.Model3D(label="Output glb", interactive=False)
|
||||
gr.Markdown("The glb file displayed on the grario will be dark. We recommend downloading and opening it with 3D software, such as Blender, MeshLab, etc")
|
||||
with gr.Tab("GIF"):
|
||||
result_gif = gr.Image(label="Rendered GIF", interactive=False)
|
||||
|
||||
#===============================================================
|
||||
|
||||
none = gr.State(None)
|
||||
save_folder = gr.State()
|
||||
cond_image = gr.State()
|
||||
views_image = gr.State()
|
||||
text_image = gr.State()
|
||||
|
||||
textgen_submit.click(
|
||||
fn=stage_0_t2i, inputs=[text, none, textgen_seed, textgen_step],
|
||||
outputs=[rem_bg_image, save_folder],
|
||||
).success(
|
||||
fn=stage_2_i2v, inputs=[rem_bg_image, textgen_SEED, textgen_STEP, save_folder],
|
||||
outputs=[views_image, cond_image, result_image],
|
||||
).success(
|
||||
fn=stage_3_v23, inputs=[views_image, cond_image, textgen_SEED, save_folder, textgen_max_faces, textgen_do_texture_mapping, textgen_do_render_gif],
|
||||
outputs=[result_3dobj, result_3dglb],
|
||||
).success(
|
||||
fn=stage_4_gif, inputs=[result_3dglb, save_folder, textgen_do_render_gif],
|
||||
outputs=[result_gif],
|
||||
).success(lambda: print('Text_to_3D Done ...'))
|
||||
|
||||
imggen_submit.click(
|
||||
fn=stage_0_t2i, inputs=[none, input_image, textgen_seed, textgen_step],
|
||||
outputs=[text_image, save_folder],
|
||||
).success(
|
||||
fn=stage_1_xbg, inputs=[text_image, save_folder],
|
||||
outputs=[rem_bg_image],
|
||||
).success(
|
||||
fn=stage_2_i2v, inputs=[rem_bg_image, imggen_SEED, imggen_STEP, save_folder],
|
||||
outputs=[views_image, cond_image, result_image],
|
||||
).success(
|
||||
fn=stage_3_v23, inputs=[views_image, cond_image, imggen_SEED, save_folder, imggen_max_faces, imggen_do_texture_mapping, imggen_do_render_gif],
|
||||
outputs=[result_3dobj, result_3dglb],
|
||||
).success(
|
||||
fn=stage_4_gif, inputs=[result_3dglb, save_folder, imggen_do_render_gif],
|
||||
outputs=[result_gif],
|
||||
).success(lambda: print('Image_to_3D Done ...'))
|
||||
|
||||
#===============================================================
|
||||
|
||||
gr.Markdown(CONST_CITATION)
|
||||
demo.queue(max_size=CONST_MAX_QUEUE)
|
||||
demo.launch(server_name=CONST_SERVER, server_port=CONST_PORT)
|
||||
|
BIN
assets/logo.png
Normal file
After Width: | Height: | Size: 306 KiB |
BIN
assets/overview_3.png
Normal file
After Width: | Height: | Size: 264 KiB |
BIN
assets/radar.png
Normal file
After Width: | Height: | Size: 119 KiB |
BIN
assets/runtime.png
Normal file
After Width: | Height: | Size: 38 KiB |
BIN
assets/teaser.png
Normal file
After Width: | Height: | Size: 3.0 MiB |
BIN
demos/example_000.png
Normal file
After Width: | Height: | Size: 643 KiB |
BIN
demos/example_001.png
Normal file
After Width: | Height: | Size: 798 KiB |
BIN
demos/example_002.png
Normal file
After Width: | Height: | Size: 331 KiB |
BIN
demos/example_003.png
Normal file
After Width: | Height: | Size: 1.0 MiB |
5
demos/example_list.txt
Normal file
@ -0,0 +1,5 @@
|
||||
一片绿色的树叶在白色背景上居中展现,清晰的纹理
|
||||
一只棕白相间的仓鼠,站在白色背景前。照片采用居中构图方式,卡通风格
|
||||
一盆绿色植物生长在红色花盆中,居中,写实
|
||||
a pot of green plants grows in a red flower pot.
|
||||
a lovely rabbit eating carrots
|
15
env_install.sh
Normal file
@ -0,0 +1,15 @@
|
||||
# python3.9 test success
|
||||
|
||||
pip install torch==2.2.0 torchvision==0.17.0 --index-url https://download.pytorch.org/whl/cu118
|
||||
|
||||
pip install diffusers transformers
|
||||
|
||||
pip install rembg tqdm omegaconf matplotlib opencv-python imageio jaxtyping einops
|
||||
|
||||
pip install SentencePiece accelerate trimesh PyMCubes xatlas libigl
|
||||
|
||||
pip install git+https://github.com/facebookresearch/pytorch3d
|
||||
|
||||
pip install git+https://github.com/NVlabs/nvdiffrast
|
||||
|
||||
pip install open3d
|
28
infer/__init__.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
||||
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||
|
||||
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||
# The below software and/or models in this distribution may have been
|
||||
# modified by THL A29 Limited ("Tencent Modifications").
|
||||
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
from .utils import seed_everything, timing_decorator, auto_amp_inference
|
||||
from .rembg import Removebg
|
||||
from .text_to_image import Text2Image
|
||||
from .image_to_views import Image2Views, save_gif
|
||||
from .views_to_mesh import Views2Mesh
|
||||
from .gif_render import GifRenderer
|
55
infer/gif_render.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
||||
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||
|
||||
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||
# The below software and/or models in this distribution may have been
|
||||
# modified by THL A29 Limited ("Tencent Modifications").
|
||||
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
from svrm.ldm.vis_util import render
|
||||
from .utils import seed_everything, timing_decorator
|
||||
|
||||
class GifRenderer():
|
||||
'''
|
||||
render frame(s) of mesh using pytorch3d
|
||||
'''
|
||||
def __init__(self, device="cuda:0"):
|
||||
self.device = device
|
||||
|
||||
@timing_decorator("gif render")
|
||||
def __call__(
|
||||
self,
|
||||
obj_filename,
|
||||
elev=0,
|
||||
azim=0,
|
||||
resolution=512,
|
||||
gif_dst_path='',
|
||||
n_views=120,
|
||||
fps=30,
|
||||
rgb=True
|
||||
):
|
||||
render(
|
||||
obj_filename,
|
||||
elev=elev,
|
||||
azim=azim,
|
||||
resolution=resolution,
|
||||
gif_dst_path=gif_dst_path,
|
||||
n_views=n_views,
|
||||
fps=fps,
|
||||
device=self.device,
|
||||
rgb=rgb
|
||||
)
|
81
infer/image_to_views.py
Normal file
@ -0,0 +1,81 @@
|
||||
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
||||
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||
|
||||
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||
# The below software and/or models in this distribution may have been
|
||||
# modified by THL A29 Limited ("Tencent Modifications").
|
||||
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from einops import rearrange
|
||||
from PIL import Image, ImageSequence
|
||||
|
||||
from .utils import seed_everything, timing_decorator, auto_amp_inference
|
||||
from .utils import get_parameter_number, set_parameter_grad_false
|
||||
from mvd.hunyuan3d_mvd_std_pipeline import HunYuan3D_MVD_Std_Pipeline
|
||||
from mvd.hunyuan3d_mvd_lite_pipeline import Hunyuan3d_MVD_Lite_Pipeline
|
||||
|
||||
|
||||
def save_gif(pils, save_path, df=False):
|
||||
# save a list of PIL.Image to gif
|
||||
spf = 4000 / len(pils)
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
pils[0].save(save_path, format="GIF", save_all=True, append_images=pils[1:], duration=spf, loop=0)
|
||||
return save_path
|
||||
|
||||
|
||||
class Image2Views():
|
||||
def __init__(self, device="cuda:0", use_lite=False):
|
||||
self.device = device
|
||||
if use_lite:
|
||||
self.pipe = Hunyuan3d_MVD_Lite_Pipeline.from_pretrained(
|
||||
"./weights/mvd_lite",
|
||||
torch_dtype = torch.float16,
|
||||
use_safetensors = True,
|
||||
)
|
||||
else:
|
||||
self.pipe = HunYuan3D_MVD_Std_Pipeline.from_pretrained(
|
||||
"./weights/mvd_std",
|
||||
torch_dtype = torch.float16,
|
||||
use_safetensors = True,
|
||||
)
|
||||
self.pipe = self.pipe.to(device)
|
||||
self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
|
||||
set_parameter_grad_false(self.pipe.unet)
|
||||
print('image2views unet model', get_parameter_number(self.pipe.unet))
|
||||
|
||||
@torch.no_grad()
|
||||
@timing_decorator("image to views")
|
||||
@auto_amp_inference
|
||||
def __call__(self, pil_img, seed=0, steps=50, guidance_scale=2.0, guidance_curve=lambda t:2.0):
|
||||
seed_everything(seed)
|
||||
generator = torch.Generator(device=self.device)
|
||||
res_img = self.pipe(pil_img,
|
||||
num_inference_steps=steps,
|
||||
guidance_scale=guidance_scale,
|
||||
guidance_curve=guidance_curve,
|
||||
generat=generator).images
|
||||
show_image = rearrange(np.asarray(res_img[0], dtype=np.uint8), '(n h) (m w) c -> (n m) h w c', n=3, m=2)
|
||||
pils = [res_img[1]]+[Image.fromarray(show_image[idx]) for idx in self.order]
|
||||
torch.cuda.empty_cache()
|
||||
return res_img, pils
|
||||
|
26
infer/rembg.py
Normal file
@ -0,0 +1,26 @@
|
||||
from rembg import remove, new_session
|
||||
from .utils import timing_decorator
|
||||
|
||||
class Removebg():
|
||||
def __init__(self, name="u2net"):
|
||||
'''
|
||||
name: rembg
|
||||
'''
|
||||
self.session = new_session(name)
|
||||
|
||||
@timing_decorator("remove background")
|
||||
def __call__(self, rgb_img, force=False):
|
||||
'''
|
||||
inputs:
|
||||
rgb_img: PIL.Image, with RGB mode expected
|
||||
force: bool, input is RGBA mode
|
||||
return:
|
||||
rgba_img: PIL.Image with RGBA mode
|
||||
'''
|
||||
if rgb_img.mode == "RGBA":
|
||||
if force:
|
||||
rgb_img = rgb_img.convert("RGB")
|
||||
else:
|
||||
return rgb_img
|
||||
rgba_img = remove(rgb_img, session=self.session)
|
||||
return rgba_img
|
80
infer/text_to_image.py
Normal file
@ -0,0 +1,80 @@
|
||||
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
||||
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||
|
||||
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||
# The below software and/or models in this distribution may have been
|
||||
# modified by THL A29 Limited ("Tencent Modifications").
|
||||
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
import torch
|
||||
from .utils import seed_everything, timing_decorator, auto_amp_inference
|
||||
from .utils import get_parameter_number, set_parameter_grad_false
|
||||
from diffusers import HunyuanDiTPipeline, AutoPipelineForText2Image
|
||||
|
||||
class Text2Image():
|
||||
def __init__(self, pretrain="weights/hunyuanDiT", device="cuda:0", save_memory=False):
|
||||
'''
|
||||
save_memory: if GPU memory is low, can set it
|
||||
'''
|
||||
self.save_memory = save_memory
|
||||
self.device = device
|
||||
self.pipe = AutoPipelineForText2Image.from_pretrained(
|
||||
pretrain,
|
||||
torch_dtype = torch.float16,
|
||||
enable_pag = True,
|
||||
pag_applied_layers = ["blocks.(16|17|18|19)"]
|
||||
)
|
||||
set_parameter_grad_false(self.pipe.transformer)
|
||||
print('text2image transformer model', get_parameter_number(self.pipe.transformer))
|
||||
if not save_memory:
|
||||
self.pipe = self.pipe.to(device)
|
||||
self.neg_txt = "文本,特写,裁剪,出框,最差质量,低质量,JPEG伪影,PGLY,重复,病态,残缺,多余的手指,变异的手," \
|
||||
"画得不好的手,画得不好的脸,变异,畸形,模糊,脱水,糟糕的解剖学,糟糕的比例,多余的肢体,克隆的脸," \
|
||||
"毁容,恶心的比例,畸形的肢体,缺失的手臂,缺失的腿,额外的手臂,额外的腿,融合的手指,手指太多,长脖子"
|
||||
|
||||
@torch.no_grad()
|
||||
@timing_decorator('text to image')
|
||||
@auto_amp_inference
|
||||
def __call__(self, *args, **kwargs):
|
||||
if self.save_memory:
|
||||
self.pipe = self.pipe.to(self.device)
|
||||
torch.cuda.empty_cache()
|
||||
res = self.call(*args, **kwargs)
|
||||
self.pipe = self.pipe.to("cpu")
|
||||
else:
|
||||
res = self.call(*args, **kwargs)
|
||||
torch.cuda.empty_cache()
|
||||
return res
|
||||
|
||||
def call(self, prompt, seed=0, steps=25):
|
||||
'''
|
||||
inputs:
|
||||
prompr: str
|
||||
seed: int
|
||||
steps: int
|
||||
return:
|
||||
rgb: PIL.Image
|
||||
'''
|
||||
prompt = prompt + ",白色背景,3D风格,最佳质量"
|
||||
seed_everything(seed)
|
||||
generator = torch.Generator(device=self.device)
|
||||
if seed is not None: generator = generator.manual_seed(int(seed))
|
||||
rgb = self.pipe(prompt=prompt, negative_prompt=self.neg_txt, num_inference_steps=steps,
|
||||
pag_scale=1.3, width=1024, height=1024, generator=generator, return_dict=False)[0][0]
|
||||
torch.cuda.empty_cache()
|
||||
return rgb
|
||||
|
77
infer/utils.py
Normal file
@ -0,0 +1,77 @@
|
||||
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
||||
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||
|
||||
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||
# The below software and/or models in this distribution may have been
|
||||
# modified by THL A29 Limited ("Tencent Modifications").
|
||||
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
import os
|
||||
import time
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.cuda.amp import autocast, GradScaler
|
||||
from functools import wraps
|
||||
|
||||
def seed_everything(seed):
|
||||
'''
|
||||
seed everthing
|
||||
'''
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
os.environ["PL_GLOBAL_SEED"] = str(seed)
|
||||
|
||||
def timing_decorator(category: str):
|
||||
'''
|
||||
timing_decorator: record time
|
||||
'''
|
||||
def decorator(func):
|
||||
func.call_count = 0
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
elapsed_time = end_time - start_time
|
||||
func.call_count += 1
|
||||
print(f"[HunYuan3D]-[{category}], cost time: {elapsed_time:.4f}s") # huiwen
|
||||
return result
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
def auto_amp_inference(func):
|
||||
'''
|
||||
with torch.cuda.amp.autocast()"
|
||||
xxx
|
||||
'''
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
with autocast():
|
||||
output = func(*args, **kwargs)
|
||||
return output
|
||||
return wrapper
|
||||
|
||||
def get_parameter_number(model):
|
||||
total_num = sum(p.numel() for p in model.parameters())
|
||||
trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
return {'Total': total_num, 'Trainable': trainable_num}
|
||||
|
||||
def set_parameter_grad_false(model):
|
||||
for p in model.parameters():
|
||||
p.requires_grad = False
|
94
infer/views_to_mesh.py
Normal file
@ -0,0 +1,94 @@
|
||||
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
||||
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||
|
||||
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||
# The below software and/or models in this distribution may have been
|
||||
# modified by THL A29 Limited ("Tencent Modifications").
|
||||
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import random
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from einops import rearrange
|
||||
from PIL import Image, ImageSequence
|
||||
|
||||
from .utils import seed_everything, timing_decorator, auto_amp_inference
|
||||
from .utils import get_parameter_number, set_parameter_grad_false
|
||||
from svrm.predictor import MV23DPredictor
|
||||
|
||||
|
||||
class Views2Mesh():
|
||||
def __init__(self, mv23d_cfg_path, mv23d_ckt_path, device="cuda:0", use_lite=False):
|
||||
'''
|
||||
mv23d_cfg_path: config yaml file
|
||||
mv23d_ckt_path: path to ckpt
|
||||
use_lite:
|
||||
'''
|
||||
self.mv23d_predictor = MV23DPredictor(mv23d_ckt_path, mv23d_cfg_path, device=device)
|
||||
self.mv23d_predictor.model.eval()
|
||||
self.order = [0, 1, 2, 3, 4, 5] if use_lite else [0, 2, 4, 5, 3, 1]
|
||||
set_parameter_grad_false(self.mv23d_predictor.model)
|
||||
print('view2mesh model', get_parameter_number(self.mv23d_predictor.model))
|
||||
|
||||
@torch.no_grad()
|
||||
@timing_decorator("views to mesh")
|
||||
@auto_amp_inference
|
||||
def __call__(
|
||||
self,
|
||||
views_pil=None,
|
||||
cond_pil=None,
|
||||
gif_pil=None,
|
||||
seed=0,
|
||||
target_face_count = 10000,
|
||||
do_texture_mapping = True,
|
||||
save_folder='./outputs/test'
|
||||
):
|
||||
'''
|
||||
can set views_pil, cond_pil simutaously or set gif_pil only
|
||||
seed: int
|
||||
target_face_count: int
|
||||
save_folder: path to save mesh files
|
||||
'''
|
||||
save_dir = save_folder
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
if views_pil is not None and cond_pil is not None:
|
||||
show_image = rearrange(np.asarray(views_pil, dtype=np.uint8),
|
||||
'(n h) (m w) c -> (n m) h w c', n=3, m=2)
|
||||
views = [Image.fromarray(show_image[idx]) for idx in self.order]
|
||||
image_list = [cond_pil]+ views
|
||||
image_list = [img.convert('RGB') for img in image_list]
|
||||
elif gif_pil is not None:
|
||||
image_list = [img.convert('RGB') for img in ImageSequence.Iterator(gif_pil)]
|
||||
|
||||
image_input = image_list[0]
|
||||
image_list = image_list[1:] + image_list[:1]
|
||||
|
||||
seed_everything(seed)
|
||||
self.mv23d_predictor.predict(
|
||||
image_list,
|
||||
save_dir = save_dir,
|
||||
image_input = image_input,
|
||||
target_face_count = target_face_count,
|
||||
do_texture_mapping = do_texture_mapping
|
||||
)
|
||||
torch.cuda.empty_cache()
|
||||
return save_dir
|
||||
|
146
main.py
Normal file
@ -0,0 +1,146 @@
|
||||
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
||||
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||
|
||||
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||
# The below software and/or models in this distribution may have been
|
||||
# modified by THL A29 Limited ("Tencent Modifications").
|
||||
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.l
|
||||
|
||||
import os
|
||||
import torch
|
||||
from PIL import Image
|
||||
import argparse
|
||||
|
||||
from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--use_lite", default=False, action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mv23d_cfg_path", default="./svrm/configs/svrm.yaml", type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mv23d_ckt_path", default="weights/svrm/svrm.safetensors", type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text2image_path", default="weights/hunyuanDiT", type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_folder", default="./outputs/test/", type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_prompt", default="", type=str,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_prompt", default="", type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device", default="cuda:0", type=str
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t2i_seed", default=0, type=int
|
||||
)
|
||||
parser.add_argument(
|
||||
"--t2i_steps", default=25, type=int
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_seed", default=0, type=int
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_steps", default=50, type=int
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_faces_num", default=80000, type=int,
|
||||
help="max num of face, suggest 80000 for effect, 10000 for speed"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save_memory", default=False, action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_texture_mapping", default=False, action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_render", default=False, action="store_true"
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
assert not (args.text_prompt and args.image_prompt), "Text and image can only be given to one"
|
||||
assert args.text_prompt or args.image_prompt, "Text and image can only be given to one"
|
||||
|
||||
# init model
|
||||
rembg_model = Removebg()
|
||||
image_to_views_model = Image2Views(device=args.device, use_lite=args.use_lite)
|
||||
views_to_mesh_model = Views2Mesh(args.mv23d_cfg_path, args.mv23d_ckt_path, args.device, use_lite=args.use_lite)
|
||||
if args.text_prompt:
|
||||
text_to_image_model = Text2Image(
|
||||
pretrain = args.text2image_path,
|
||||
device = args.device,
|
||||
save_memory = args.save_memory
|
||||
)
|
||||
if args.do_render:
|
||||
gif_renderer = GifRenderer(device=args.device)
|
||||
|
||||
# ---- ----- ---- ---- ---- ----
|
||||
|
||||
os.makedirs(args.save_folder, exist_ok=True)
|
||||
|
||||
# stage 1, text to image
|
||||
if args.text_prompt:
|
||||
res_rgb_pil = text_to_image_model(
|
||||
args.text_prompt,
|
||||
seed=args.t2i_seed,
|
||||
steps=args.t2i_steps
|
||||
)
|
||||
res_rgb_pil.save(os.path.join(args.save_folder, "img.jpg"))
|
||||
elif args.image_prompt:
|
||||
res_rgb_pil = Image.open(args.image_prompt)
|
||||
|
||||
# stage 2, remove back ground
|
||||
res_rgba_pil = rembg_model(res_rgb_pil)
|
||||
res_rgb_pil.save(os.path.join(args.save_folder, "img_nobg.png"))
|
||||
|
||||
# stage 3, image to views
|
||||
(views_grid_pil, cond_img), view_pil_list = image_to_views_model(
|
||||
res_rgba_pil,
|
||||
seed = args.gen_seed,
|
||||
steps = args.gen_steps
|
||||
)
|
||||
views_grid_pil.save(os.path.join(args.save_folder, "views.jpg"))
|
||||
|
||||
# stage 4, views to mesh
|
||||
views_to_mesh_model(
|
||||
views_grid_pil,
|
||||
cond_img,
|
||||
seed = args.gen_seed,
|
||||
target_face_count = args.max_faces_num,
|
||||
save_folder = args.save_folder,
|
||||
do_texture_mapping = args.do_texture_mapping
|
||||
)
|
||||
|
||||
# stage 5, render gif
|
||||
if args.do_render:
|
||||
gif_renderer(
|
||||
os.path.join(args.save_folder, 'mesh.obj'),
|
||||
gif_dst_path = os.path.join(args.save_folder, 'output.gif'),
|
||||
)
|
0
mvd/__init__.py
Normal file
493
mvd/hunyuan3d_mvd_lite_pipeline.py
Normal file
@ -0,0 +1,493 @@
|
||||
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
||||
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||
|
||||
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||
# The below software and/or models in this distribution may have been
|
||||
# modified by THL A29 Limited ("Tencent Modifications").
|
||||
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
import math
|
||||
import numpy
|
||||
import torch
|
||||
import inspect
|
||||
import warnings
|
||||
from PIL import Image
|
||||
from einops import rearrange
|
||||
import torch.nn.functional as F
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.configuration_utils import FrozenDict
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers import DDPMScheduler, EulerAncestralDiscreteScheduler, ImagePipelineOutput
|
||||
from diffusers.loaders import (
|
||||
FromSingleFileMixin,
|
||||
LoraLoaderMixin,
|
||||
TextualInversionLoaderMixin
|
||||
)
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection
|
||||
)
|
||||
from diffusers.models.attention_processor import (
|
||||
Attention,
|
||||
AttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
AttnProcessor2_0
|
||||
)
|
||||
|
||||
from .utils import to_rgb_image, white_out_background, recenter_img
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from here import Hunyuan3d_MVD_Qing_Pipeline
|
||||
|
||||
>>> pipe = Hunyuan3d_MVD_Qing_Pipeline.from_pretrained(
|
||||
... "Tencent-Hunyuan-3D/MVD-Qing", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> img = Image.open("demo.png")
|
||||
>>> res_img = pipe(img).images[0]
|
||||
"""
|
||||
|
||||
def unscale_latents(latents): return latents / 0.75 + 0.22
|
||||
def unscale_image (image ): return image / 0.50 * 0.80
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
|
||||
class ReferenceOnlyAttnProc(torch.nn.Module):
|
||||
# reference attention
|
||||
def __init__(self, chained_proc, enabled=False, name=None):
|
||||
super().__init__()
|
||||
self.enabled = enabled
|
||||
self.chained_proc = chained_proc
|
||||
self.name = name
|
||||
|
||||
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, mode="w", ref_dict=None):
|
||||
if encoder_hidden_states is None: encoder_hidden_states = hidden_states
|
||||
if self.enabled:
|
||||
if mode == 'w':
|
||||
ref_dict[self.name] = encoder_hidden_states
|
||||
elif mode == 'r':
|
||||
encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
|
||||
res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
|
||||
return res
|
||||
|
||||
|
||||
# class RowWiseAttnProcessor2_0:
|
||||
# def __call__(self, attn,
|
||||
# hidden_states,
|
||||
# encoder_hidden_states=None,
|
||||
# attention_mask=None,
|
||||
# temb=None,
|
||||
# num_views=6,
|
||||
# *args,
|
||||
# **kwargs):
|
||||
# residual = hidden_states
|
||||
# if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
# input_ndim = hidden_states.ndim
|
||||
# if input_ndim == 4:
|
||||
# batch_size, channel, height, width = hidden_states.shape
|
||||
# hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
# if encoder_hidden_states is None:
|
||||
# batch_size, sequence_length, _ = hidden_states.shape
|
||||
# else:
|
||||
# batch_size, sequence_length, _ = encoder_hidden_states.shape
|
||||
|
||||
# if attention_mask is not None:
|
||||
# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
# if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
# query = attn.to_q(hidden_states)
|
||||
# if encoder_hidden_states is None: encoder_hidden_states = hidden_states
|
||||
# elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
# # encoder_hidden_states [B, 6hw+hw, C] if ref att
|
||||
# key = attn.to_k(encoder_hidden_states) # [B, Vhw+hw, C]
|
||||
# value = attn.to_v(encoder_hidden_states) # [B, Vhw+hw, C]
|
||||
|
||||
# mv_flag = hidden_states.shape[1] < encoder_hidden_states.shape[1] and encoder_hidden_states.shape[1] != 77
|
||||
# if mv_flag:
|
||||
# target_size = int(math.sqrt(hidden_states.shape[1] // num_views))
|
||||
# assert target_size ** 2 * num_views == hidden_states.shape[1]
|
||||
|
||||
# gen_key = key[:, :num_views*target_size*target_size, :]
|
||||
# ref_key = key[:, num_views*target_size*target_size:, :]
|
||||
# gen_value = value[:, :num_views*target_size*target_size, :]
|
||||
# ref_value = value[:, num_views*target_size*target_size:, :]
|
||||
|
||||
# # rowwise attention
|
||||
# query, gen_key, gen_value = \
|
||||
# rearrange( query, "b (v1 h v2 w) c -> (b h) (v1 v2 w) c",
|
||||
# v1=num_views//2, v2=2, h=target_size, w=target_size), \
|
||||
# rearrange( gen_key, "b (v1 h v2 w) c -> (b h) (v1 v2 w) c",
|
||||
# v1=num_views//2, v2=2, h=target_size, w=target_size), \
|
||||
# rearrange(gen_value, "b (v1 h v2 w) c -> (b h) (v1 v2 w) c",
|
||||
# v1=num_views//2, v2=2, h=target_size, w=target_size)
|
||||
|
||||
# inner_dim = key.shape[-1]
|
||||
# ref_size = int(math.sqrt(ref_key.shape[1]))
|
||||
# ref_key_expanded = ref_key.view(batch_size, 1, ref_size * ref_size, inner_dim)
|
||||
# ref_key_expanded = ref_key_expanded.expand(-1, target_size, -1, -1).contiguous()
|
||||
# ref_key_expanded = ref_key_expanded.view(batch_size * target_size, ref_size * ref_size, inner_dim)
|
||||
# key = torch.cat([ gen_key, ref_key_expanded], dim=1)
|
||||
|
||||
# ref_value_expanded = ref_value.view(batch_size, 1, ref_size * ref_size, inner_dim)
|
||||
# ref_value_expanded = ref_value_expanded.expand(-1, target_size, -1, -1).contiguous()
|
||||
# ref_value_expanded = ref_value_expanded.view(batch_size * target_size, ref_size * ref_size, inner_dim)
|
||||
# value = torch.cat([gen_value, ref_value_expanded], dim=1)
|
||||
# h = target_size
|
||||
# else:
|
||||
# target_size = int(math.sqrt(hidden_states.shape[1]))
|
||||
# h = 1
|
||||
# num_views = 1
|
||||
|
||||
# inner_dim = key.shape[-1]
|
||||
# head_dim = inner_dim // attn.heads
|
||||
|
||||
# query = query.view(batch_size * h, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
# key = key.view(batch_size * h, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
# value = value.view(batch_size * h, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# hidden_states = F.scaled_dot_product_attention(query, key, value,
|
||||
# attn_mask=attention_mask,
|
||||
# dropout_p=0.0,
|
||||
# is_causal=False)
|
||||
# hidden_states = hidden_states.transpose(1, 2).reshape(batch_size * h,
|
||||
# -1,
|
||||
# attn.heads * head_dim).to(query.dtype)
|
||||
# hidden_states = attn.to_out[1](attn.to_out[0](hidden_states))
|
||||
|
||||
# if mv_flag: hidden_states = rearrange(hidden_states, "(b h) (v1 v2 w) c -> b (v1 h v2 w) c",
|
||||
# b=batch_size, v1=num_views//2,
|
||||
# v2=2, h=target_size, w=target_size)
|
||||
|
||||
# if input_ndim == 4:
|
||||
# hidden_states = hidden_states.transpose(-1, -2)
|
||||
# hidden_states = hidden_states.reshape(batch_size,
|
||||
# channel,
|
||||
# target_size,
|
||||
# target_size)
|
||||
# if attn.residual_connection: hidden_states = hidden_states + residual
|
||||
# hidden_states = hidden_states / attn.rescale_output_factor
|
||||
# return hidden_states
|
||||
|
||||
|
||||
class RefOnlyNoisedUNet(torch.nn.Module):
|
||||
def __init__(self, unet, train_sched, val_sched):
|
||||
super().__init__()
|
||||
self.unet = unet
|
||||
self.train_sched = train_sched
|
||||
self.val_sched = val_sched
|
||||
|
||||
unet_lora_attn_procs = dict()
|
||||
for name, _ in unet.attn_processors.items():
|
||||
unet_lora_attn_procs[name] = ReferenceOnlyAttnProc(AttnProcessor2_0(),
|
||||
enabled=name.endswith("attn1.processor"),
|
||||
name=name)
|
||||
unet.set_attn_processor(unet_lora_attn_procs)
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
try:
|
||||
return super().__getattr__(name)
|
||||
except AttributeError:
|
||||
return getattr(self.unet, name)
|
||||
|
||||
def forward(self, sample, timestep, encoder_hidden_states, *args, cross_attention_kwargs, **kwargs):
|
||||
cond_lat = cross_attention_kwargs['cond_lat']
|
||||
noise = torch.randn_like(cond_lat)
|
||||
if self.training:
|
||||
noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep)
|
||||
noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep)
|
||||
else:
|
||||
noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1))
|
||||
noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
|
||||
|
||||
ref_dict = {}
|
||||
self.unet(noisy_cond_lat,
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
*args,
|
||||
cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict),
|
||||
**kwargs)
|
||||
return self.unet(sample,
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
*args,
|
||||
cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict),
|
||||
**kwargs)
|
||||
|
||||
|
||||
class Hunyuan3d_MVD_Lite_Pipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin):
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
vision_encoder: CLIPVisionModelWithProjection,
|
||||
feature_extractor_clip: CLIPImageProcessor,
|
||||
feature_extractor_vae: CLIPImageProcessor,
|
||||
ramping_coefficients: Optional[list] = None,
|
||||
safety_checker=None,
|
||||
):
|
||||
DiffusionPipeline.__init__(self)
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
text_encoder=text_encoder,
|
||||
vision_encoder=vision_encoder,
|
||||
feature_extractor_vae=feature_extractor_vae,
|
||||
feature_extractor_clip=feature_extractor_clip)
|
||||
'''
|
||||
rewrite the stable diffusion pipeline
|
||||
vae: vae
|
||||
unet: unet
|
||||
tokenizer: tokenizer
|
||||
scheduler: scheduler
|
||||
text_encoder: text_encoder
|
||||
vision_encoder: vision_encoder
|
||||
feature_extractor_vae: feature_extractor_vae
|
||||
feature_extractor_clip: feature_extractor_clip
|
||||
'''
|
||||
self.register_to_config(ramping_coefficients=ramping_coefficients)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
extra_step_kwargs = {}
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_eta: extra_step_kwargs["eta"] = eta
|
||||
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator: extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
@torch.no_grad()
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt=None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
):
|
||||
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
|
||||
self._lora_scale = lora_scale
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = text_inputs.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)[0]
|
||||
|
||||
if self.text_encoder is not None:
|
||||
prompt_embeds_dtype = self.text_encoder.dtype
|
||||
elif self.unet is not None:
|
||||
prompt_embeds_dtype = self.unet.dtype
|
||||
else:
|
||||
prompt_embeds_dtype = prompt_embeds.dtype
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens: List[str]
|
||||
if negative_prompt is None: uncond_tokens = [""] * batch_size
|
||||
elif prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError()
|
||||
elif isinstance(negative_prompt, str): uncond_tokens = [negative_prompt]
|
||||
elif batch_size != len(negative_prompt): raise ValueError()
|
||||
else: uncond_tokens = negative_prompt
|
||||
if isinstance(self, TextualInversionLoaderMixin):
|
||||
uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
|
||||
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
|
||||
if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
|
||||
attention_mask = uncond_input.attention_mask.to(device)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(device), attention_mask=attention_mask)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
||||
|
||||
return prompt_embeds
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_condition_image(self, image: torch.Tensor): return self.vae.encode(image).latent_dist.sample()
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, image=None,
|
||||
width=640,
|
||||
height=960,
|
||||
num_inference_steps=75,
|
||||
return_dict=True,
|
||||
generator=None,
|
||||
**kwargs):
|
||||
batch_size = 1
|
||||
num_images_per_prompt = 1
|
||||
output_type = 'pil'
|
||||
do_classifier_free_guidance = True
|
||||
guidance_rescale = 0.
|
||||
if isinstance(self.unet, UNet2DConditionModel):
|
||||
self.unet = RefOnlyNoisedUNet(self.unet, None, self.scheduler).eval()
|
||||
|
||||
cond_image = recenter_img(image)
|
||||
cond_image = to_rgb_image(image)
|
||||
image = cond_image
|
||||
image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values
|
||||
image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values
|
||||
image_1 = image_1.to(device=self.vae.device, dtype=self.vae.dtype)
|
||||
image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype)
|
||||
|
||||
cond_lat = self.encode_condition_image(image_1)
|
||||
negative_lat = self.encode_condition_image(torch.zeros_like(image_1))
|
||||
cond_lat = torch.cat([negative_lat, cond_lat])
|
||||
cross_attention_kwargs = dict(cond_lat=cond_lat)
|
||||
|
||||
global_embeds = self.vision_encoder(image_2, output_hidden_states=False).image_embeds.unsqueeze(-2)
|
||||
encoder_hidden_states = self._encode_prompt('', self.device, num_images_per_prompt, False)
|
||||
ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
|
||||
prompt_embeds = torch.cat([encoder_hidden_states, encoder_hidden_states + global_embeds * ramp])
|
||||
|
||||
device = self._execution_device
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
None)
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, 0.0)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
|
||||
# set adaptive cfg
|
||||
# the image order is:
|
||||
# [0, 60,
|
||||
# 120, 180,
|
||||
# 240, 300]
|
||||
# the cfg is set as 3, 2.5, 2, 1.5
|
||||
|
||||
tmp_guidance_scale = torch.ones_like(latents)
|
||||
tmp_guidance_scale[:, :, :40, :40] = 3
|
||||
tmp_guidance_scale[:, :, :40, 40:] = 2.5
|
||||
tmp_guidance_scale[:, :, 40:80, :40] = 2
|
||||
tmp_guidance_scale[:, :, 40:80, 40:] = 1.5
|
||||
tmp_guidance_scale[:, :, 80:120, :40] = 2
|
||||
tmp_guidance_scale[:, :, 80:120, 40:] = 2.5
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
noise_pred = self.unet(latent_model_input, t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
return_dict=False)[0]
|
||||
|
||||
adaptive_guidance_scale = (2 + 16 * (t / 1000) ** 5) / 3
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + \
|
||||
tmp_guidance_scale * adaptive_guidance_scale * \
|
||||
(noise_pred_text - noise_pred_uncond)
|
||||
|
||||
if do_classifier_free_guidance and guidance_rescale > 0.0:
|
||||
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
|
||||
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
if i==len(timesteps)-1 or ((i+1)>num_warmup_steps and (i+1)%self.scheduler.order==0):
|
||||
progress_bar.update()
|
||||
|
||||
latents = unscale_latents(latents)
|
||||
image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0])
|
||||
image = self.image_processor.postprocess(image, output_type='pil')[0]
|
||||
image = [image, cond_image]
|
||||
return ImagePipelineOutput(images=image) if return_dict else (image,)
|
||||
|
471
mvd/hunyuan3d_mvd_std_pipeline.py
Normal file
@ -0,0 +1,471 @@
|
||||
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
||||
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||
|
||||
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||
# The below software and/or models in this distribution may have been
|
||||
# modified by THL A29 Limited ("Tencent Modifications").
|
||||
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import diffusers
|
||||
from diffusers.image_processor import VaeImageProcessor
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.models.attention_processor import (
|
||||
Attention,
|
||||
AttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
AttnProcessor2_0
|
||||
)
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
DiffusionPipeline,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
UNet2DConditionModel,
|
||||
ImagePipelineOutput
|
||||
)
|
||||
import transformers
|
||||
from transformers import (
|
||||
CLIPImageProcessor,
|
||||
CLIPTextModel,
|
||||
CLIPTokenizer,
|
||||
CLIPVisionModelWithProjection,
|
||||
CLIPTextModelWithProjection
|
||||
)
|
||||
|
||||
from .utils import to_rgb_image, white_out_background, recenter_img
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import Hunyuan3d_MVD_XL_Pipeline
|
||||
|
||||
>>> pipe = Hunyuan3d_MVD_XL_Pipeline.from_pretrained(
|
||||
... "Tencent-Hunyuan-3D/MVD-XL", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> img = Image.open("demo.png")
|
||||
>>> res_img = pipe(img).images[0]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
|
||||
def scale_latents(latents): return (latents - 0.22) * 0.75
|
||||
def unscale_latents(latents): return (latents / 0.75) + 0.22
|
||||
def scale_image(image): return (image - 0.5) / 0.5
|
||||
def scale_image_2(image): return (image * 0.5) / 0.8
|
||||
def unscale_image(image): return (image * 0.5) + 0.5
|
||||
def unscale_image_2(image): return (image * 0.8) / 0.5
|
||||
|
||||
|
||||
|
||||
|
||||
class ReferenceOnlyAttnProc(torch.nn.Module):
|
||||
def __init__(self, chained_proc, enabled=False, name=None):
|
||||
super().__init__()
|
||||
self.enabled = enabled
|
||||
self.chained_proc = chained_proc
|
||||
self.name = name
|
||||
|
||||
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, mode="w", ref_dict=None):
|
||||
encoder_hidden_states = hidden_states if encoder_hidden_states is None else encoder_hidden_states
|
||||
if self.enabled:
|
||||
if mode == 'w': ref_dict[self.name] = encoder_hidden_states
|
||||
elif mode == 'r': encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1)
|
||||
else: raise Exception(f"mode should not be {mode}")
|
||||
return self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask)
|
||||
|
||||
|
||||
class RefOnlyNoisedUNet(torch.nn.Module):
|
||||
def __init__(self, unet, scheduler) -> None:
|
||||
super().__init__()
|
||||
self.unet = unet
|
||||
self.scheduler = scheduler
|
||||
|
||||
unet_attn_procs = dict()
|
||||
for name, _ in unet.attn_processors.items():
|
||||
if torch.__version__ >= '2.0': default_attn_proc = AttnProcessor2_0()
|
||||
elif is_xformers_available(): default_attn_proc = XFormersAttnProcessor()
|
||||
else: default_attn_proc = AttnProcessor()
|
||||
unet_attn_procs[name] = ReferenceOnlyAttnProc(
|
||||
default_attn_proc, enabled=name.endswith("attn1.processor"), name=name
|
||||
)
|
||||
unet.set_attn_processor(unet_attn_procs)
|
||||
|
||||
def __getattr__(self, name: str):
|
||||
try:
|
||||
return super().__getattr__(name)
|
||||
except AttributeError:
|
||||
return getattr(self.unet, name)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
down_block_res_samples: Optional[Tuple[torch.Tensor]] = None,
|
||||
mid_block_res_sample: Optional[Tuple[torch.Tensor]] = None,
|
||||
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||
return_dict: bool = True,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
dtype = self.unet.dtype
|
||||
|
||||
# cond_lat add same level noise
|
||||
cond_lat = cross_attention_kwargs['cond_lat']
|
||||
noise = torch.randn_like(cond_lat)
|
||||
|
||||
noisy_cond_lat = self.scheduler.add_noise(cond_lat, noise, timestep.reshape(-1))
|
||||
noisy_cond_lat = self.scheduler.scale_model_input(noisy_cond_lat, timestep.reshape(-1))
|
||||
|
||||
ref_dict = {}
|
||||
|
||||
_ = self.unet(
|
||||
noisy_cond_lat,
|
||||
timestep,
|
||||
encoder_hidden_states = encoder_hidden_states,
|
||||
class_labels = class_labels,
|
||||
cross_attention_kwargs = dict(mode="w", ref_dict=ref_dict),
|
||||
added_cond_kwargs = added_cond_kwargs,
|
||||
return_dict = return_dict,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
res = self.unet(
|
||||
sample,
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
class_labels=class_labels,
|
||||
cross_attention_kwargs = dict(mode="r", ref_dict=ref_dict),
|
||||
down_block_additional_residuals = [
|
||||
sample.to(dtype=dtype) for sample in down_block_res_samples
|
||||
] if down_block_res_samples is not None else None,
|
||||
mid_block_additional_residual = (
|
||||
mid_block_res_sample.to(dtype=dtype)
|
||||
if mid_block_res_sample is not None else None),
|
||||
added_cond_kwargs = added_cond_kwargs,
|
||||
return_dict = return_dict,
|
||||
**kwargs
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
|
||||
class HunYuan3D_MVD_Std_Pipeline(diffusers.DiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKL,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
feature_extractor_vae: CLIPImageProcessor,
|
||||
vision_processor: CLIPImageProcessor,
|
||||
vision_encoder: CLIPVisionModelWithProjection,
|
||||
vision_encoder_2: CLIPVisionModelWithProjection,
|
||||
ramping_coefficients: Optional[list] = None,
|
||||
add_watermarker: Optional[bool] = None,
|
||||
safety_checker = None,
|
||||
):
|
||||
DiffusionPipeline.__init__(self)
|
||||
|
||||
self.register_modules(
|
||||
vae=vae, unet=unet, scheduler=scheduler, safety_checker=None, feature_extractor_vae=feature_extractor_vae,
|
||||
vision_processor=vision_processor, vision_encoder=vision_encoder, vision_encoder_2=vision_encoder_2,
|
||||
)
|
||||
self.register_to_config( ramping_coefficients = ramping_coefficients)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
self.default_sample_size = self.unet.config.sample_size
|
||||
self.watermark = None
|
||||
self.prepare_init = False
|
||||
|
||||
def prepare(self):
|
||||
assert isinstance(self.unet, UNet2DConditionModel), "unet should be UNet2DConditionModel"
|
||||
self.unet = RefOnlyNoisedUNet(self.unet, self.scheduler).eval()
|
||||
self.prepare_init = True
|
||||
|
||||
def encode_image(self, image: torch.Tensor, scale_factor: bool = False):
|
||||
latent = self.vae.encode(image).latent_dist.sample()
|
||||
return (latent * self.vae.config.scaling_factor) if scale_factor else latent
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // self.vae_scale_factor,
|
||||
int(width) // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
|
||||
# scale the initial noise by the standard deviation required by the scheduler
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def _get_add_time_ids(
|
||||
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
|
||||
):
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
|
||||
passed_add_embed_dim = (
|
||||
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
|
||||
)
|
||||
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
|
||||
|
||||
if expected_add_embed_dim != passed_add_embed_dim:
|
||||
raise ValueError(
|
||||
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, " \
|
||||
f"but a vector of {passed_add_embed_dim} was created. The model has an incorrect config." \
|
||||
f" Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
|
||||
)
|
||||
|
||||
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
|
||||
return add_time_ids
|
||||
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
||||
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
||||
# and should be between [0, 1]
|
||||
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
extra_step_kwargs = {}
|
||||
if accepts_eta: extra_step_kwargs["eta"] = eta
|
||||
|
||||
# check if the scheduler accepts generator
|
||||
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_generator: extra_step_kwargs["generator"] = generator
|
||||
return extra_step_kwargs
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
image: Image.Image = None,
|
||||
guidance_scale = 2.0,
|
||||
output_type: Optional[str] = "pil",
|
||||
num_inference_steps: int = 50,
|
||||
return_dict: bool = True,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
latent: torch.Tensor = None,
|
||||
guidance_curve = None,
|
||||
**kwargs
|
||||
):
|
||||
if not self.prepare_init:
|
||||
self.prepare()
|
||||
|
||||
here = dict(device=self.vae.device, dtype=self.vae.dtype)
|
||||
|
||||
batch_size = 1
|
||||
num_images_per_prompt = 1
|
||||
width, height = 512 * 2, 512 * 3
|
||||
target_size = original_size = (height, width)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._cross_attention_kwargs = cross_attention_kwargs
|
||||
self._interrupt = False
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
self.vae.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=latent,
|
||||
)
|
||||
|
||||
# Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
|
||||
# Prepare added time ids & embeddings
|
||||
text_encoder_projection_dim = 1280
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=self.vae.dtype,
|
||||
text_encoder_projection_dim=text_encoder_projection_dim,
|
||||
)
|
||||
negative_add_time_ids = add_time_ids
|
||||
|
||||
# hw: preprocess
|
||||
cond_image = recenter_img(image)
|
||||
cond_image = to_rgb_image(image)
|
||||
image_vae = self.feature_extractor_vae(images=cond_image, return_tensors="pt").pixel_values.to(**here)
|
||||
image_clip = self.vision_processor(images=cond_image, return_tensors="pt").pixel_values.to(**here)
|
||||
|
||||
# hw: get cond_lat from cond_img using vae
|
||||
cond_lat = self.encode_image(image_vae, scale_factor=False)
|
||||
negative_lat = self.encode_image(torch.zeros_like(image_vae), scale_factor=False)
|
||||
cond_lat = torch.cat([negative_lat, cond_lat])
|
||||
|
||||
# hw: get visual global embedding using clip
|
||||
global_embeds_1 = self.vision_encoder(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2)
|
||||
global_embeds_2 = self.vision_encoder_2(image_clip, output_hidden_states=False).image_embeds.unsqueeze(-2)
|
||||
global_embeds = torch.concat([global_embeds_1, global_embeds_2], dim=-1)
|
||||
|
||||
ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1)
|
||||
prompt_embeds = self.uc_text_emb.to(**here)
|
||||
pooled_prompt_embeds = self.uc_text_emb_2.to(**here)
|
||||
|
||||
prompt_embeds = prompt_embeds + global_embeds * ramp
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
|
||||
add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
|
||||
|
||||
prompt_embeds = prompt_embeds.to(device)
|
||||
add_text_embeds = add_text_embeds.to(device)
|
||||
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
timestep_cond = None
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
if guidance_curve is None:
|
||||
guidance_curve = lambda t: guidance_scale
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# predict the noise residual
|
||||
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
|
||||
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep_cond=timestep_cond,
|
||||
cross_attention_kwargs=dict(cond_lat=cond_lat),
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
|
||||
# cur_guidance_scale = self.guidance_scale
|
||||
cur_guidance_scale = guidance_curve(t) # 1.5 + 2.5 * ((t/1000)**2)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + cur_guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# cur_guidance_scale_topleft = (cur_guidance_scale - 1.0) * 4 + 1.0
|
||||
# noise_pred_top_left = noise_pred_uncond +
|
||||
# cur_guidance_scale_topleft * (noise_pred_text - noise_pred_uncond)
|
||||
# _, _, h, w = noise_pred.shape
|
||||
# noise_pred[:, :, :h//3, :w//2] = noise_pred_top_left[:, :, :h//3, :w//2]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
latents = unscale_latents(latents)
|
||||
|
||||
if output_type=="latent":
|
||||
image = latents
|
||||
else:
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image = unscale_image(unscale_image_2(image)).clamp(0, 1)
|
||||
image = [
|
||||
Image.fromarray((image[0]*255+0.5).clamp_(0, 255).permute(1, 2, 0).cpu().numpy().astype("uint8")),
|
||||
# self.image_processor.postprocess(image, output_type=output_type)[0],
|
||||
cond_image.resize((512, 512))
|
||||
]
|
||||
|
||||
if not return_dict: return (image,)
|
||||
return ImagePipelineOutput(images=image)
|
||||
|
||||
def save_pretrained(self, save_directory):
|
||||
# uc_text_emb.pt and uc_text_emb_2.pt are inferenced and saved in advance
|
||||
super().save_pretrained(save_directory)
|
||||
torch.save(self.uc_text_emb, os.path.join(save_directory, "uc_text_emb.pt"))
|
||||
torch.save(self.uc_text_emb_2, os.path.join(save_directory, "uc_text_emb_2.pt"))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
|
||||
# uc_text_emb.pt and uc_text_emb_2.pt are inferenced and saved in advance
|
||||
pipeline = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
|
||||
pipeline.uc_text_emb = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb.pt"))
|
||||
pipeline.uc_text_emb_2 = torch.load(os.path.join(pretrained_model_name_or_path, "uc_text_emb_2.pt"))
|
||||
return pipeline
|
85
mvd/utils.py
Normal file
@ -0,0 +1,85 @@
|
||||
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
||||
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||
|
||||
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||
# The below software and/or models in this distribution may have been
|
||||
# modified by THL A29 Limited ("Tencent Modifications").
|
||||
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
def to_rgb_image(maybe_rgba: Image.Image):
|
||||
'''
|
||||
convert a PIL.Image to rgb mode with white background
|
||||
maybe_rgba: PIL.Image
|
||||
return: PIL.Image
|
||||
'''
|
||||
if maybe_rgba.mode == 'RGB':
|
||||
return maybe_rgba
|
||||
elif maybe_rgba.mode == 'RGBA':
|
||||
rgba = maybe_rgba
|
||||
img = np.random.randint(255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=np.uint8)
|
||||
img = Image.fromarray(img, 'RGB')
|
||||
img.paste(rgba, mask=rgba.getchannel('A'))
|
||||
return img
|
||||
else:
|
||||
raise ValueError("Unsupported image type.", maybe_rgba.mode)
|
||||
|
||||
def white_out_background(pil_img, is_gray_fg=True):
|
||||
data = pil_img.getdata()
|
||||
new_data = []
|
||||
# convert fore-ground white to gray
|
||||
for r, g, b, a in data:
|
||||
if a < 16:
|
||||
new_data.append((255, 255, 255, 0)) # back-ground to be black
|
||||
else:
|
||||
is_white = is_gray_fg and (r>235) and (g>235) and (b>235)
|
||||
new_r = 235 if is_white else r
|
||||
new_g = 235 if is_white else g
|
||||
new_b = 235 if is_white else b
|
||||
new_data.append((new_r, new_g, new_b, a))
|
||||
pil_img.putdata(new_data)
|
||||
return pil_img
|
||||
|
||||
def recenter_img(img, size=512, color=(255,255,255)):
|
||||
img = white_out_background(img)
|
||||
mask = np.array(img)[..., 3]
|
||||
image = np.array(img)[..., :3]
|
||||
|
||||
H, W, C = image.shape
|
||||
coords = np.nonzero(mask)
|
||||
x_min, x_max = coords[0].min(), coords[0].max()
|
||||
y_min, y_max = coords[1].min(), coords[1].max()
|
||||
h = x_max - x_min
|
||||
w = y_max - y_min
|
||||
if h == 0 or w == 0: raise ValueError
|
||||
roi = image[x_min:x_max, y_min:y_max]
|
||||
|
||||
border_ratio = 0.15 # 0.2
|
||||
pad_h = int(h * border_ratio)
|
||||
pad_w = int(w * border_ratio)
|
||||
|
||||
result_tmp = np.full((h + pad_h, w + pad_w, C), color, dtype=np.uint8)
|
||||
result_tmp[pad_h // 2: pad_h // 2 + h, pad_w // 2: pad_w // 2 + w] = roi
|
||||
|
||||
cur_h, cur_w = result_tmp.shape[:2]
|
||||
side = max(cur_h, cur_w)
|
||||
result = np.full((side, side, C), color, dtype=np.uint8)
|
||||
result[(side-cur_h)//2:(side-cur_h)//2+cur_h, (side-cur_w)//2:(side - cur_w)//2+cur_w,:] = result_tmp
|
||||
result = Image.fromarray(result)
|
||||
return result.resize((size, size), Image.LANCZOS) if size else result
|
8
scripts/image_to_3d.sh
Normal file
@ -0,0 +1,8 @@
|
||||
# image to 3d
|
||||
|
||||
python main.py \
|
||||
--image_prompt ./demos/example_000.png \
|
||||
--save_folder ./outputs/test/ \
|
||||
--max_faces_num 90000 \
|
||||
--do_texture \
|
||||
--do_render
|
8
scripts/image_to_3d_demo.sh
Normal file
@ -0,0 +1,8 @@
|
||||
# image to 3d
|
||||
|
||||
python main.py \
|
||||
--image_prompt ./demos/example_000.png \
|
||||
--save_folder ./outputs/test/ \
|
||||
--max_faces_num 90000 \
|
||||
--do_texture_mapping \
|
||||
--do_render
|
6
scripts/image_to_3d_fast.sh
Normal file
@ -0,0 +1,6 @@
|
||||
# image to 3d fast
|
||||
python main.py \
|
||||
--image_prompt ./demos/example_000.png \
|
||||
--save_folder ./outputs/test/ \
|
||||
--max_faces_num 10000 \
|
||||
--use_lite
|
6
scripts/image_to_3d_fast_demo.sh
Normal file
@ -0,0 +1,6 @@
|
||||
# image to 3d fast
|
||||
python main.py \
|
||||
--image_prompt ./demos/example_000.png \
|
||||
--save_folder ./outputs/test/ \
|
||||
--max_faces_num 10000 \
|
||||
--use_lite
|
7
scripts/text_to_3d.sh
Normal file
@ -0,0 +1,7 @@
|
||||
# text to 3d fast
|
||||
python main.py \
|
||||
--text_prompt "a lovely cat" \
|
||||
--save_folder ./outputs/test/ \
|
||||
--max_faces_num 90000 \
|
||||
--do_texture \
|
||||
--do_render
|
7
scripts/text_to_3d_demo.sh
Normal file
@ -0,0 +1,7 @@
|
||||
# text to 3d fast
|
||||
python main.py \
|
||||
--text_prompt "a lovely rabbit" \
|
||||
--save_folder ./outputs/test/ \
|
||||
--max_faces_num 90000 \
|
||||
--do_texture_mapping \
|
||||
--do_render
|
6
scripts/text_to_3d_fast.sh
Normal file
@ -0,0 +1,6 @@
|
||||
# text to 3d fast
|
||||
python main.py \
|
||||
--text_prompt "一个广式茶杯" \
|
||||
--save_folder ./outputs/test/ \
|
||||
--max_faces_num 10000 \
|
||||
--use_lite
|
6
scripts/text_to_3d_fast_demo.sh
Normal file
@ -0,0 +1,6 @@
|
||||
# text to 3d fast
|
||||
python main.py \
|
||||
--text_prompt "一个广式茶杯" \
|
||||
--save_folder ./outputs/test/ \
|
||||
--max_faces_num 10000 \
|
||||
--use_lite
|
BIN
svrm/.DS_Store
vendored
Normal file
32
svrm/configs/2024-10-24T22-36-18-project.yaml
Normal file
@ -0,0 +1,32 @@
|
||||
model:
|
||||
base_learning_rate: 3.0e-05
|
||||
target: svrm.ldm.models.svrm.SVRMModel
|
||||
params:
|
||||
|
||||
img_encoder_config:
|
||||
target: svrm.ldm.modules.encoders.dinov2_mod.FrozenDinoV2ImageEmbedder
|
||||
params:
|
||||
version: dinov2_vitb14
|
||||
|
||||
img_to_triplane_config:
|
||||
target: svrm.ldm.modules.translator.img_to_triplane.ImgToTriplaneModel
|
||||
params:
|
||||
pos_emb_size: 64
|
||||
pos_emb_dim: 1024
|
||||
cam_cond_dim: 20
|
||||
n_heads: 16
|
||||
d_head: 64
|
||||
depth: 16
|
||||
context_dim: 768
|
||||
triplane_dim: 120
|
||||
use_fp16: true
|
||||
use_bf16: false
|
||||
upsample_time: 2
|
||||
|
||||
render_config:
|
||||
target: svrm.ldm.modules.rendering_neus.synthesizer.TriplaneSynthesizer
|
||||
params:
|
||||
triplane_dim: 120
|
||||
samples_per_ray: 128
|
||||
|
||||
|
32
svrm/configs/svrm.yaml
Normal file
@ -0,0 +1,32 @@
|
||||
model:
|
||||
base_learning_rate: 3.0e-05
|
||||
target: svrm.ldm.models.svrm.SVRMModel
|
||||
params:
|
||||
|
||||
img_encoder_config:
|
||||
target: svrm.ldm.modules.encoders.dinov2_mod.FrozenDinoV2ImageEmbedder
|
||||
params:
|
||||
version: dinov2_vitb14
|
||||
|
||||
img_to_triplane_config:
|
||||
target: svrm.ldm.modules.translator.img_to_triplane.ImgToTriplaneModel
|
||||
params:
|
||||
pos_emb_size: 64
|
||||
pos_emb_dim: 1024
|
||||
cam_cond_dim: 20
|
||||
n_heads: 16
|
||||
d_head: 64
|
||||
depth: 16
|
||||
context_dim: 768
|
||||
triplane_dim: 120
|
||||
use_fp16: true
|
||||
use_bf16: false
|
||||
upsample_time: 2
|
||||
|
||||
render_config:
|
||||
target: svrm.ldm.modules.rendering_neus.synthesizer.TriplaneSynthesizer
|
||||
params:
|
||||
triplane_dim: 120
|
||||
samples_per_ray: 128
|
||||
|
||||
|
BIN
svrm/ldm/.DS_Store
vendored
Normal file
263
svrm/ldm/models/svrm.py
Normal file
@ -0,0 +1,263 @@
|
||||
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
||||
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||
|
||||
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||
# The below software and/or models in this distribution may have been
|
||||
# modified by THL A29 Limited ("Tencent Modifications").
|
||||
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
import os
|
||||
import time
|
||||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
import itertools
|
||||
import shutil
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
try:
|
||||
import trimesh
|
||||
import mcubes
|
||||
import xatlas
|
||||
import open3d as o3d
|
||||
except:
|
||||
raise "failed to import 3d libraries "
|
||||
|
||||
from ..modules.rendering_neus.mesh import Mesh
|
||||
from ..modules.rendering_neus.rasterize import NVDiffRasterizerContext
|
||||
|
||||
from ..utils.ops import scale_tensor
|
||||
from ..util import count_params, instantiate_from_config
|
||||
from ..vis_util import render
|
||||
|
||||
|
||||
def unwrap_uv(v_pos, t_pos_idx):
|
||||
print("Using xatlas to perform UV unwrapping, may take a while ...")
|
||||
atlas = xatlas.Atlas()
|
||||
atlas.add_mesh(v_pos, t_pos_idx)
|
||||
atlas.generate(xatlas.ChartOptions(), xatlas.PackOptions())
|
||||
_, indices, uvs = atlas.get_mesh(0)
|
||||
indices = indices.astype(np.int64, casting="same_kind")
|
||||
return uvs, indices
|
||||
|
||||
|
||||
def uv_padding(image, hole_mask, uv_padding_size = 2):
|
||||
return cv2.inpaint(
|
||||
(image.detach().cpu().numpy() * 255).astype(np.uint8),
|
||||
(hole_mask.detach().cpu().numpy() * 255).astype(np.uint8),
|
||||
uv_padding_size,
|
||||
cv2.INPAINT_TELEA
|
||||
)
|
||||
|
||||
def refine_mesh(vtx_refine, faces_refine):
|
||||
mesh = o3d.geometry.TriangleMesh(
|
||||
vertices=o3d.utility.Vector3dVector(vtx_refine),
|
||||
triangles=o3d.utility.Vector3iVector(faces_refine))
|
||||
|
||||
mesh = mesh.remove_unreferenced_vertices()
|
||||
mesh = mesh.remove_duplicated_triangles()
|
||||
mesh = mesh.remove_duplicated_vertices()
|
||||
|
||||
voxel_size = max(mesh.get_max_bound() - mesh.get_min_bound())
|
||||
|
||||
mesh = mesh.simplify_vertex_clustering(
|
||||
voxel_size=0.007, # 0.005
|
||||
contraction=o3d.geometry.SimplificationContraction.Average)
|
||||
|
||||
mesh = mesh.filter_smooth_simple(number_of_iterations=2)
|
||||
|
||||
vtx_refine = np.asarray(mesh.vertices).astype(np.float32)
|
||||
faces_refine = np.asarray(mesh.triangles)
|
||||
return vtx_refine, faces_refine, mesh
|
||||
|
||||
|
||||
class SVRMModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
img_encoder_config,
|
||||
img_to_triplane_config,
|
||||
render_config,
|
||||
device = "cuda:0",
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.img_encoder = instantiate_from_config(img_encoder_config).half()
|
||||
self.img_to_triplane_decoder = instantiate_from_config(img_to_triplane_config).half()
|
||||
self.render = instantiate_from_config(render_config).half()
|
||||
self.device = device
|
||||
count_params(self, verbose=True)
|
||||
|
||||
@torch.no_grad()
|
||||
def export_mesh_with_uv(
|
||||
self,
|
||||
data,
|
||||
mesh_size: int = 384,
|
||||
ctx = None,
|
||||
context_type = 'cuda',
|
||||
texture_res = 1024,
|
||||
target_face_count = 10000,
|
||||
do_texture_mapping = True,
|
||||
out_dir = 'outputs/test'
|
||||
):
|
||||
"""
|
||||
color_type: 0 for ray texture, 1 for vertices texture
|
||||
"""
|
||||
st = time.time()
|
||||
here = {'device': self.device, 'dtype': torch.float16}
|
||||
input_view_image = data["input_view"].to(**here) # [b, m, c, h, w]
|
||||
input_view_cam = data["input_view_cam"].to(**here) # [b, m, 20]
|
||||
|
||||
batch_size, input_view_num, *_ = input_view_image.shape
|
||||
assert batch_size == 1, "batch size should be 1"
|
||||
|
||||
input_view_image = rearrange(input_view_image, 'b m c h w -> (b m) c h w')
|
||||
input_view_cam = rearrange(input_view_cam, 'b m d -> (b m) d')
|
||||
input_view_feat = self.img_encoder(input_view_image, input_view_cam)
|
||||
input_view_feat = rearrange(input_view_feat, '(b m) l d -> b (l m) d', m=input_view_num)
|
||||
|
||||
# -- decoder
|
||||
torch.cuda.empty_cache()
|
||||
triplane_gen = self.img_to_triplane_decoder(input_view_feat) # [b, 3, tri_dim, h, w]
|
||||
del input_view_feat
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# --- triplane nerf render
|
||||
|
||||
cur_triplane = triplane_gen[0:1]
|
||||
|
||||
aabb = torch.tensor([[-0.6, -0.6, -0.6], [0.6, 0.6, 0.6]]).unsqueeze(0).to(**here)
|
||||
grid_out = self.render.forward_grid(planes=cur_triplane, grid_size=mesh_size, aabb=aabb)
|
||||
|
||||
print(f"=====> LRM forward time: {time.time() - st}")
|
||||
st = time.time()
|
||||
|
||||
vtx, faces = mcubes.marching_cubes(0. - grid_out['sdf'].squeeze(0).squeeze(-1).cpu().float().numpy(), 0)
|
||||
|
||||
bbox = aabb[0].cpu().numpy()
|
||||
vtx = vtx / (mesh_size - 1)
|
||||
vtx = vtx * (bbox[1] - bbox[0]) + bbox[0]
|
||||
|
||||
# refine mesh
|
||||
vtx_refine, faces_refine, mesh = refine_mesh(vtx, faces)
|
||||
|
||||
# reduce faces
|
||||
if faces_refine.shape[0] > target_face_count:
|
||||
print(f"reduce face: {faces_refine.shape[0]} -> {target_face_count}")
|
||||
mesh = o3d.geometry.TriangleMesh(
|
||||
vertices = o3d.utility.Vector3dVector(vtx_refine),
|
||||
triangles = o3d.utility.Vector3iVector(faces_refine)
|
||||
)
|
||||
|
||||
# Function to simplify mesh using Quadric Error Metric Decimation by Garland and Heckbert
|
||||
mesh = mesh.simplify_quadric_decimation(target_face_count, boundary_weight=1.0)
|
||||
|
||||
mesh = Mesh(
|
||||
v_pos = torch.from_numpy(np.asarray(mesh.vertices)).to(self.device),
|
||||
t_pos_idx = torch.from_numpy(np.asarray(mesh.triangles)).to(self.device),
|
||||
v_rgb = torch.from_numpy(np.asarray(mesh.vertex_colors)).to(self.device)
|
||||
)
|
||||
vtx_refine = mesh.v_pos.cpu().numpy()
|
||||
faces_refine = mesh.t_pos_idx.cpu().numpy()
|
||||
|
||||
vtx_colors = self.render.forward_points(cur_triplane, torch.tensor(vtx_refine).unsqueeze(0).to(**here))
|
||||
vtx_colors = vtx_colors['rgb'].float().squeeze(0).cpu().numpy()
|
||||
|
||||
color_ratio = 0.8 # increase brightness
|
||||
with open(f'{out_dir}/mesh_with_colors.obj', 'w') as fid:
|
||||
verts = vtx_refine[:, [1,2,0]]
|
||||
for pidx, pp in enumerate(verts):
|
||||
color = vtx_colors[pidx]
|
||||
color = [color[0]**color_ratio, color[1]**color_ratio, color[2]**color_ratio]
|
||||
fid.write('v %f %f %f %f %f %f\n' % (pp[0], pp[1], pp[2], color[0], color[1], color[2]))
|
||||
for i, f in enumerate(faces_refine):
|
||||
f1 = f + 1
|
||||
fid.write('f %d %d %d\n' % (f1[0], f1[1], f1[2]))
|
||||
|
||||
mesh = trimesh.load_mesh(f'{out_dir}/mesh_with_colors.obj')
|
||||
print(f"=====> generate mesh with vertex shading time: {time.time() - st}")
|
||||
st = time.time()
|
||||
|
||||
if not do_texture_mapping:
|
||||
shutil.copy(f'{out_dir}/mesh_with_colors.obj', f'{out_dir}/mesh.obj')
|
||||
mesh.export(f'{out_dir}/mesh.glb', file_type='glb')
|
||||
return None
|
||||
|
||||
########## export texture ########
|
||||
st = time.time()
|
||||
|
||||
# uv unwrap
|
||||
vtx_tex, t_tex_idx = unwrap_uv(vtx_refine, faces_refine)
|
||||
vtx_refine = torch.from_numpy(vtx_refine).to(self.device)
|
||||
faces_refine = torch.from_numpy(faces_refine).to(self.device)
|
||||
t_tex_idx = torch.from_numpy(t_tex_idx).to(self.device)
|
||||
uv_clip = torch.from_numpy(vtx_tex * 2.0 - 1.0).to(self.device)
|
||||
|
||||
# rasterize
|
||||
ctx = NVDiffRasterizerContext(context_type, cur_triplane.device) if ctx is None else ctx
|
||||
rast = ctx.rasterize_one(
|
||||
torch.cat([
|
||||
uv_clip,
|
||||
torch.zeros_like(uv_clip[..., 0:1]),
|
||||
torch.ones_like(uv_clip[..., 0:1])
|
||||
], dim=-1),
|
||||
t_tex_idx,
|
||||
(texture_res, texture_res)
|
||||
)[0]
|
||||
hole_mask = ~(rast[:, :, 3] > 0)
|
||||
|
||||
# Interpolate world space position
|
||||
gb_pos = ctx.interpolate_one(vtx_refine, rast[None, ...], faces_refine)[0][0]
|
||||
with torch.no_grad():
|
||||
gb_mask_pos_scale = scale_tensor(gb_pos.unsqueeze(0).view(1, -1, 3), (-1, 1), (-1, 1))
|
||||
tex_map = self.render.forward_points(cur_triplane, gb_mask_pos_scale)['rgb']
|
||||
tex_map = tex_map.float().squeeze(0) # (0, 1)
|
||||
tex_map = tex_map.view((texture_res, texture_res, 3))
|
||||
img = uv_padding(tex_map, hole_mask)
|
||||
img = ((img/255.0) ** color_ratio) * 255 # increase brightness
|
||||
img = img.clip(0, 255).astype(np.uint8)
|
||||
|
||||
verts = vtx_refine.cpu().numpy()[:, [1,2,0]]
|
||||
faces = faces_refine.cpu().numpy()
|
||||
|
||||
with open(f'{out_dir}/texture.mtl', 'w') as fid:
|
||||
fid.write('newmtl material_0\n')
|
||||
fid.write("Ka 1.000 1.000 1.000\n")
|
||||
fid.write("Kd 1.000 1.000 1.000\n")
|
||||
fid.write("Ks 0.000 0.000 0.000\n")
|
||||
fid.write("d 1.0\n")
|
||||
fid.write("illum 2\n")
|
||||
fid.write(f'map_Kd texture.png\n')
|
||||
|
||||
with open(f'{out_dir}/mesh.obj', 'w') as fid:
|
||||
fid.write(f'mtllib texture.mtl\n')
|
||||
for pidx, pp in enumerate(verts):
|
||||
fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2]))
|
||||
for pidx, pp in enumerate(vtx_tex):
|
||||
fid.write('vt %f %f\n' % (pp[0], 1 - pp[1]))
|
||||
fid.write('usemtl material_0\n')
|
||||
for i, f in enumerate(faces):
|
||||
f1 = f + 1
|
||||
f2 = t_tex_idx[i] + 1
|
||||
fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2],))
|
||||
|
||||
cv2.imwrite(f'{out_dir}/texture.png', img[..., [2, 1, 0]])
|
||||
mesh = trimesh.load_mesh(f'{out_dir}/mesh.obj')
|
||||
mesh.export(f'{out_dir}/mesh.glb', file_type='glb')
|
||||
|
457
svrm/ldm/modules/attention.py
Normal file
@ -0,0 +1,457 @@
|
||||
from inspect import isfunction
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
import numpy as np
|
||||
|
||||
FLASH_IS_AVAILABLE = XFORMERS_IS_AVAILBLE = False
|
||||
try:
|
||||
from flash_attn import flash_attn_qkvpacked_func, flash_attn_func
|
||||
FLASH_IS_AVAILABLE = True
|
||||
except:
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
XFORMERS_IS_AVAILBLE = True
|
||||
except:
|
||||
pass
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return{el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def max_neg_value(t):
|
||||
return -torch.finfo(t.dtype).max
|
||||
|
||||
|
||||
def init_(tensor):
|
||||
dim = tensor.shape[-1]
|
||||
std = 1 / math.sqrt(dim)
|
||||
tensor.uniform_(-std, std)
|
||||
return tensor
|
||||
|
||||
def checkpoint(func, inputs, params, flag):
|
||||
"""
|
||||
Evaluate a function without caching intermediate activations, allowing for
|
||||
reduced memory at the expense of extra compute in the backward pass.
|
||||
:param func: the function to evaluate.
|
||||
:param inputs: the argument sequence to pass to `func`.
|
||||
:param params: a sequence of parameters `func` depends on but does not
|
||||
explicitly take as arguments.
|
||||
:param flag: if False, disable gradient checkpointing.
|
||||
"""
|
||||
if flag:
|
||||
args = tuple(inputs) + tuple(params)
|
||||
return CheckpointFunction.apply(func, len(inputs), *args)
|
||||
else:
|
||||
return func(*inputs)
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, length, *args):
|
||||
ctx.run_function = run_function
|
||||
ctx.input_tensors = list(args[:length])
|
||||
ctx.input_params = list(args[length:])
|
||||
|
||||
with torch.no_grad():
|
||||
output_tensors = ctx.run_function(*ctx.input_tensors)
|
||||
return output_tensors
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
||||
with torch.enable_grad():
|
||||
# Fixes a bug where the first op in run_function modifies the
|
||||
# Tensor storage in place, which is not allowed for detach()'d
|
||||
# Tensors.
|
||||
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
||||
output_tensors = ctx.run_function(*shallow_copies)
|
||||
input_grads = torch.autograd.grad(
|
||||
output_tensors,
|
||||
ctx.input_tensors + ctx.input_params,
|
||||
output_grads,
|
||||
allow_unused=True,
|
||||
)
|
||||
del ctx.input_tensors
|
||||
del ctx.input_params
|
||||
del output_tensors
|
||||
return (None, None) + input_grads
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
def __init__(self, dim, heads=4, dim_head=32):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
q = rearrange(q, 'b c h w -> b (h w) c')
|
||||
k = rearrange(k, 'b c h w -> b c (h w)')
|
||||
w_ = torch.einsum('bij,bjk->bik', q, k)
|
||||
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = rearrange(v, 'b c h w -> b c (h w)')
|
||||
w_ = rearrange(w_, 'b i j -> b j i')
|
||||
h_ = torch.einsum('bij,bjk->bik', v, w_)
|
||||
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
# attention, what we cannot get enough of
|
||||
attn = sim.softmax(dim=-1)
|
||||
out = einsum('b i j, b j d -> b i d', attn, v) # [b*h, n, d]
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class FlashAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||
f"{heads} heads.")
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
self.dropout = dropout
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
context = default(context, x)
|
||||
h = self.heads
|
||||
dtype = torch.bfloat16 # torch.half
|
||||
q = self.to_q(x).to(dtype)
|
||||
k = self.to_k(context).to(dtype)
|
||||
v = self.to_v(context).to(dtype)
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q, k, v)) # q is [b, 3079, 16, 64]
|
||||
out = flash_attn_func(q, k, v, dropout_p=self.dropout, softmax_scale=None, causal=False, window_size=(-1, -1)) # out is same shape to q
|
||||
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
|
||||
return self.to_out(out.float())
|
||||
|
||||
class MemoryEfficientCrossAttention(nn.Module):
|
||||
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
|
||||
super().__init__()
|
||||
print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using "
|
||||
f"{heads} heads.")
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
b, _, _ = q.shape
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3)
|
||||
.reshape(b, t.shape[1], self.heads, self.dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b * self.heads, t.shape[1], self.dim_head)
|
||||
.contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# actually compute the attention, what we cannot get enough of
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
|
||||
|
||||
if exists(mask):
|
||||
raise NotImplementedError
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, self.heads, out.shape[1], self.dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, out.shape[1], self.heads * self.dim_head)
|
||||
)
|
||||
return self.to_out(out)
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True,
|
||||
disable_self_attn=False):
|
||||
super().__init__()
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
||||
self.norm1 = Fp32LayerNorm(dim)
|
||||
self.norm2 = Fp32LayerNorm(dim)
|
||||
self.norm3 = Fp32LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
ATTENTION_MODES = {
|
||||
"softmax": CrossAttention, # vanilla attention
|
||||
"softmax-xformers": MemoryEfficientCrossAttention,
|
||||
"softmax-flash": FlashAttention
|
||||
}
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
|
||||
class Fp32LayerNorm(nn.LayerNorm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
class AdaNorm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
nn.Linear(dim, 2 * dim, bias=True)
|
||||
)
|
||||
self.norm = Fp32LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x, c): # x is fp32, c is fp16
|
||||
shift, scale = self.adaLN_modulation(c.float()).chunk(2, dim=1) # bf16
|
||||
x = modulate(self.norm(x), shift, scale) # fp32
|
||||
return x
|
||||
|
||||
|
||||
class BasicTransformerBlockLRM(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, \
|
||||
checkpoint=True):
|
||||
super().__init__()
|
||||
|
||||
attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
|
||||
attn_mode = "softmax-flash" if FLASH_IS_AVAILABLE else attn_mode
|
||||
assert attn_mode in ATTENTION_MODES
|
||||
attn_cls = ATTENTION_MODES[attn_mode]
|
||||
|
||||
self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, \
|
||||
context_dim=context_dim) # cross-attn
|
||||
self.attn2 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, \
|
||||
context_dim=None) # self-attn
|
||||
|
||||
self.norm1 = Fp32LayerNorm(dim)
|
||||
self.norm2 = Fp32LayerNorm(dim)
|
||||
self.norm3 = Fp32LayerNorm(dim)
|
||||
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None, cam_emb=None): # (torch.float32, torch.float32, torch.bfloat16)
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
|
||||
|
||||
def _forward(self, x, context=None, cam_emb=None):
|
||||
|
||||
x = self.attn1(self.norm1(x), context=context) + x # cross-attn
|
||||
x = self.attn2(self.norm2(x), context=None) + x # self-attn
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
|
||||
return x
|
||||
|
||||
class ImgToTriplaneTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data.
|
||||
First, project the input (aka embedding)
|
||||
and reshape to b, t, d.
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
"""
|
||||
def __init__(self, query_dim, n_heads, d_head, depth=1, dropout=0., context_dim=None, triplane_size=64):
|
||||
super().__init__()
|
||||
|
||||
self.transformer_blocks = nn.ModuleList([
|
||||
BasicTransformerBlockLRM(query_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
||||
for d in range(depth)])
|
||||
|
||||
self.norm = Fp32LayerNorm(query_dim, eps=1e-6)
|
||||
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
# Initialize transformer layers:
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
if module.weight is not None:
|
||||
nn.init.constant_(module.weight, 1.0)
|
||||
self.apply(_basic_init)
|
||||
|
||||
def forward(self, x, context=None, cam_emb=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context=context)
|
||||
x = self.norm(x)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
|
||||
|
0
svrm/ldm/modules/encoders/__init__.py
Normal file
0
svrm/ldm/modules/encoders/dinov2/__init__.py
Normal file
0
svrm/ldm/modules/encoders/dinov2/hub/__init__.py
Normal file
156
svrm/ldm/modules/encoders/dinov2/hub/backbones.py
Normal file
@ -0,0 +1,156 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
|
||||
|
||||
|
||||
class Weights(Enum):
|
||||
LVD142M = "LVD142M"
|
||||
|
||||
|
||||
def _make_dinov2_model(
|
||||
*,
|
||||
arch_name: str = "vit_large",
|
||||
img_size: int = 518,
|
||||
patch_size: int = 14,
|
||||
init_values: float = 1.0,
|
||||
ffn_layer: str = "mlp",
|
||||
block_chunks: int = 0,
|
||||
num_register_tokens: int = 0,
|
||||
interpolate_antialias: bool = False,
|
||||
interpolate_offset: float = 0.1,
|
||||
pretrained: bool = True,
|
||||
weights: Union[Weights, str] = Weights.LVD142M,
|
||||
**kwargs,
|
||||
):
|
||||
from ..models import vision_transformer as vits
|
||||
|
||||
if isinstance(weights, str):
|
||||
try:
|
||||
weights = Weights[weights]
|
||||
except KeyError:
|
||||
raise AssertionError(f"Unsupported weights: {weights}")
|
||||
|
||||
model_base_name = _make_dinov2_model_name(arch_name, patch_size)
|
||||
vit_kwargs = dict(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
init_values=init_values,
|
||||
ffn_layer=ffn_layer,
|
||||
block_chunks=block_chunks,
|
||||
num_register_tokens=num_register_tokens,
|
||||
interpolate_antialias=interpolate_antialias,
|
||||
interpolate_offset=interpolate_offset,
|
||||
)
|
||||
vit_kwargs.update(**kwargs)
|
||||
model = vits.__dict__[arch_name](**vit_kwargs)
|
||||
|
||||
if pretrained:
|
||||
model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
|
||||
url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
|
||||
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
||||
"""
|
||||
DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
|
||||
"""
|
||||
return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
|
||||
|
||||
|
||||
def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
||||
"""
|
||||
DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
|
||||
"""
|
||||
return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
|
||||
|
||||
|
||||
def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
||||
"""
|
||||
DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
|
||||
"""
|
||||
return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
|
||||
|
||||
|
||||
def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
||||
"""
|
||||
DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
|
||||
"""
|
||||
return _make_dinov2_model(
|
||||
arch_name="vit_giant2",
|
||||
ffn_layer="swiglufused",
|
||||
weights=weights,
|
||||
pretrained=pretrained,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
||||
"""
|
||||
DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
||||
"""
|
||||
return _make_dinov2_model(
|
||||
arch_name="vit_small",
|
||||
pretrained=pretrained,
|
||||
weights=weights,
|
||||
num_register_tokens=4,
|
||||
interpolate_antialias=True,
|
||||
interpolate_offset=0.0,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
||||
"""
|
||||
DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
||||
"""
|
||||
return _make_dinov2_model(
|
||||
arch_name="vit_base",
|
||||
pretrained=pretrained,
|
||||
weights=weights,
|
||||
num_register_tokens=4,
|
||||
interpolate_antialias=True,
|
||||
interpolate_offset=0.0,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
||||
"""
|
||||
DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
||||
"""
|
||||
return _make_dinov2_model(
|
||||
arch_name="vit_large",
|
||||
pretrained=pretrained,
|
||||
weights=weights,
|
||||
num_register_tokens=4,
|
||||
interpolate_antialias=True,
|
||||
interpolate_offset=0.0,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
||||
"""
|
||||
DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
||||
"""
|
||||
return _make_dinov2_model(
|
||||
arch_name="vit_giant2",
|
||||
ffn_layer="swiglufused",
|
||||
weights=weights,
|
||||
pretrained=pretrained,
|
||||
num_register_tokens=4,
|
||||
interpolate_antialias=True,
|
||||
interpolate_offset=0.0,
|
||||
**kwargs,
|
||||
)
|
39
svrm/ldm/modules/encoders/dinov2/hub/utils.py
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import itertools
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
|
||||
|
||||
|
||||
def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
|
||||
compact_arch_name = arch_name.replace("_", "")[:4]
|
||||
registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
|
||||
return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
|
||||
|
||||
|
||||
class CenterPadding(nn.Module):
|
||||
def __init__(self, multiple):
|
||||
super().__init__()
|
||||
self.multiple = multiple
|
||||
|
||||
def _get_pad(self, size):
|
||||
new_size = math.ceil(size / self.multiple) * self.multiple
|
||||
pad_size = new_size - size
|
||||
pad_size_left = pad_size // 2
|
||||
pad_size_right = pad_size - pad_size_left
|
||||
return pad_size_left, pad_size_right
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, x):
|
||||
pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
|
||||
output = F.pad(x, pads)
|
||||
return output
|
11
svrm/ldm/modules/encoders/dinov2/layers/__init__.py
Normal file
@ -0,0 +1,11 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .dino_head import DINOHead
|
||||
from .mlp import Mlp
|
||||
from .patch_embed import PatchEmbed
|
||||
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
||||
from .block import NestedTensorBlockMod
|
||||
from .attention import MemEffAttention
|
89
svrm/ldm/modules/encoders/dinov2/layers/attention.py
Normal file
@ -0,0 +1,89 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
||||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
||||
try:
|
||||
if XFORMERS_ENABLED:
|
||||
from xformers.ops import memory_efficient_attention, unbind
|
||||
|
||||
XFORMERS_AVAILABLE = True
|
||||
warnings.warn("xFormers is available (Attention)")
|
||||
else:
|
||||
warnings.warn("xFormers is disabled (Attention)")
|
||||
raise ImportError
|
||||
except ImportError:
|
||||
XFORMERS_AVAILABLE = False
|
||||
warnings.warn("xFormers is not available (Attention)")
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
proj_bias: bool = True,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
|
||||
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class MemEffAttention(Attention):
|
||||
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
||||
if not XFORMERS_AVAILABLE:
|
||||
if attn_bias is not None:
|
||||
raise AssertionError("xFormers is required for using nested tensors")
|
||||
return super().forward(x)
|
||||
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
||||
|
||||
q, k, v = unbind(qkv, 2)
|
||||
|
||||
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
||||
x = x.reshape([B, N, C])
|
||||
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
269
svrm/ldm/modules/encoders/dinov2/layers/block.py
Normal file
@ -0,0 +1,269 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
||||
|
||||
import os
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Callable, List, Any, Tuple, Dict
|
||||
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
|
||||
from .attention import Attention, MemEffAttention
|
||||
from .drop_path import DropPath
|
||||
from .layer_scale import LayerScale
|
||||
from .mlp import Mlp
|
||||
|
||||
from ....attention import AdaNorm
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
||||
try:
|
||||
if XFORMERS_ENABLED:
|
||||
from xformers.ops import fmha, scaled_index_add, index_select_cat
|
||||
|
||||
XFORMERS_AVAILABLE = True
|
||||
warnings.warn("xFormers is available (Block)")
|
||||
else:
|
||||
warnings.warn("xFormers is disabled (Block)")
|
||||
raise ImportError
|
||||
except ImportError:
|
||||
XFORMERS_AVAILABLE = False
|
||||
|
||||
warnings.warn("xFormers is not available (Block)")
|
||||
|
||||
|
||||
class BlockMod(nn.Module):
|
||||
'''
|
||||
using Modified Block, see below
|
||||
'''
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = False,
|
||||
proj_bias: bool = True,
|
||||
ffn_bias: bool = True,
|
||||
drop: float = 0.0,
|
||||
attn_drop: float = 0.0,
|
||||
init_values=None,
|
||||
drop_path: float = 0.0,
|
||||
act_layer: Callable[..., nn.Module] = nn.GELU,
|
||||
norm_layer: Callable[..., nn.Module] = AdaNorm,
|
||||
attn_class: Callable[..., nn.Module] = Attention,
|
||||
ffn_layer: Callable[..., nn.Module] = Mlp,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = attn_class(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
proj_bias=proj_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
)
|
||||
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
||||
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = ffn_layer(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
bias=ffn_bias,
|
||||
)
|
||||
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
||||
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
self.sample_drop_ratio = drop_path
|
||||
|
||||
def forward(self, x: Tensor, cam_emb: Tensor) -> Tensor:
|
||||
def attn_residual_func(x: Tensor, cam_emb: Tensor = None) -> Tensor:
|
||||
return self.ls1(self.attn(self.norm1(x, cam_emb)))
|
||||
|
||||
def ffn_residual_func(x: Tensor, cam_emb: Tensor = None) -> Tensor:
|
||||
return self.ls2(self.mlp(self.norm2(x, cam_emb)))
|
||||
|
||||
if self.training and self.sample_drop_ratio > 0.1:
|
||||
# the overhead is compensated only for a drop path rate larger than 0.1
|
||||
x = drop_add_residual_stochastic_depth(
|
||||
x,
|
||||
residual_func=attn_residual_func,
|
||||
sample_drop_ratio=self.sample_drop_ratio,
|
||||
)
|
||||
x = drop_add_residual_stochastic_depth(
|
||||
x,
|
||||
residual_func=ffn_residual_func,
|
||||
sample_drop_ratio=self.sample_drop_ratio,
|
||||
)
|
||||
elif self.training and self.sample_drop_ratio > 0.0:
|
||||
x = x + self.drop_path1(attn_residual_func(x, cam_emb))
|
||||
x = x + self.drop_path1(ffn_residual_func(x, cam_emb)) # FIXME: drop_path2
|
||||
else:
|
||||
x = x + attn_residual_func(x, cam_emb)
|
||||
x = x + ffn_residual_func(x, cam_emb)
|
||||
return x
|
||||
|
||||
|
||||
def drop_add_residual_stochastic_depth(
|
||||
x: Tensor,
|
||||
residual_func: Callable[[Tensor], Tensor],
|
||||
sample_drop_ratio: float = 0.0,
|
||||
) -> Tensor:
|
||||
# drop_add_residual_stochastic_depth_list
|
||||
|
||||
# 1) extract subset using permutation
|
||||
b, n, d = x.shape
|
||||
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
||||
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
||||
x_subset = x[brange]
|
||||
|
||||
# 2) apply residual_func to get residual
|
||||
residual = residual_func(x_subset)
|
||||
|
||||
x_flat = x.flatten(1)
|
||||
residual = residual.flatten(1)
|
||||
|
||||
residual_scale_factor = b / sample_subset_size
|
||||
|
||||
# 3) add the residual
|
||||
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
||||
return x_plus_residual.view_as(x)
|
||||
|
||||
|
||||
def get_branges_scales(x, sample_drop_ratio=0.0):
|
||||
# get_branges_scales
|
||||
b, n, d = x.shape
|
||||
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
||||
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
||||
residual_scale_factor = b / sample_subset_size
|
||||
return brange, residual_scale_factor
|
||||
|
||||
|
||||
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
||||
# add residuals
|
||||
if scaling_vector is None:
|
||||
x_flat = x.flatten(1)
|
||||
residual = residual.flatten(1)
|
||||
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
||||
else:
|
||||
x_plus_residual = scaled_index_add(
|
||||
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
||||
)
|
||||
return x_plus_residual
|
||||
|
||||
|
||||
attn_bias_cache: Dict[Tuple, Any] = {}
|
||||
|
||||
|
||||
def get_attn_bias_and_cat(x_list, branges=None):
|
||||
"""
|
||||
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
||||
"""
|
||||
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
||||
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
||||
if all_shapes not in attn_bias_cache.keys():
|
||||
seqlens = []
|
||||
for b, x in zip(batch_sizes, x_list):
|
||||
for _ in range(b):
|
||||
seqlens.append(x.shape[1])
|
||||
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
||||
attn_bias._batch_sizes = batch_sizes
|
||||
attn_bias_cache[all_shapes] = attn_bias
|
||||
|
||||
if branges is not None:
|
||||
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
||||
else:
|
||||
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
||||
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
||||
|
||||
return attn_bias_cache[all_shapes], cat_tensors
|
||||
|
||||
|
||||
def drop_add_residual_stochastic_list(
|
||||
x_list: List[Tensor],
|
||||
residual_func: Callable[[Tensor, Any], Tensor],
|
||||
sample_drop_ratio: float = 0.0,
|
||||
scaling_vector=None,
|
||||
) -> Tensor:
|
||||
# 1) generate random set of indices for dropping samples in the batch
|
||||
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
||||
branges = [s[0] for s in branges_scales]
|
||||
residual_scale_factors = [s[1] for s in branges_scales]
|
||||
|
||||
# 2) get attention bias and index+concat the tensors
|
||||
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
||||
|
||||
# 3) apply residual_func to get residual, and split the result
|
||||
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
||||
|
||||
outputs = []
|
||||
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
||||
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
||||
return outputs
|
||||
|
||||
|
||||
class NestedTensorBlockMod(BlockMod):
|
||||
def forward_nested(self, x_list: List[Tensor], cam_emb_list: List[Tensor]) -> List[Tensor]:
|
||||
"""
|
||||
x_list contains a list of tensors to nest together and run
|
||||
"""
|
||||
assert isinstance(self.attn, MemEffAttention)
|
||||
|
||||
if self.training and self.sample_drop_ratio > 0.0:
|
||||
|
||||
def attn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
|
||||
return self.attn(self.norm1(x, cam_emb), attn_bias=attn_bias)
|
||||
|
||||
def ffn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
|
||||
return self.mlp(self.norm2(x, cam_emb))
|
||||
|
||||
x_list = drop_add_residual_stochastic_list(
|
||||
x_list,
|
||||
residual_func=attn_residual_func,
|
||||
sample_drop_ratio=self.sample_drop_ratio,
|
||||
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
||||
)
|
||||
x_list = drop_add_residual_stochastic_list(
|
||||
x_list,
|
||||
residual_func=ffn_residual_func,
|
||||
sample_drop_ratio=self.sample_drop_ratio,
|
||||
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
|
||||
)
|
||||
return x_list
|
||||
else:
|
||||
|
||||
def attn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
|
||||
return self.ls1(self.attn(self.norm1(x, cam_emb), attn_bias=attn_bias))
|
||||
|
||||
def ffn_residual_func(x: Tensor, cam_emb: Tensor, attn_bias=None) -> Tensor:
|
||||
return self.ls2(self.mlp(self.norm2(x, cam_emb)))
|
||||
|
||||
attn_bias, x = get_attn_bias_and_cat(x_list)
|
||||
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
||||
x = x + ffn_residual_func(x)
|
||||
return attn_bias.split(x)
|
||||
|
||||
def forward(self, x_or_x_list, cam_emb_or_cam_emb_list):
|
||||
if isinstance(x_or_x_list, Tensor) and isinstance(cam_emb_or_cam_emb_list, Tensor) :
|
||||
return super().forward(x_or_x_list, cam_emb_or_cam_emb_list)
|
||||
elif isinstance(x_or_x_list, list) and isinstance(cam_emb_or_cam_emb_list, list):
|
||||
if not XFORMERS_AVAILABLE:
|
||||
raise AssertionError("xFormers is required for using nested tensors")
|
||||
return self.forward_nested(x_or_x_list, cam_emb_or_cam_emb_list)
|
||||
else:
|
||||
raise AssertionError
|
58
svrm/ldm/modules/encoders/dinov2/layers/dino_head.py
Normal file
@ -0,0 +1,58 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.init import trunc_normal_
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
class DINOHead(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_dim,
|
||||
out_dim,
|
||||
use_bn=False,
|
||||
nlayers=3,
|
||||
hidden_dim=2048,
|
||||
bottleneck_dim=256,
|
||||
mlp_bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
nlayers = max(nlayers, 1)
|
||||
self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
|
||||
self.apply(self._init_weights)
|
||||
self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
||||
self.last_layer.weight_g.data.fill_(1)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.mlp(x)
|
||||
eps = 1e-6 if x.dtype == torch.float16 else 1e-12
|
||||
x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
|
||||
x = self.last_layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
|
||||
if nlayers == 1:
|
||||
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
|
||||
else:
|
||||
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
|
||||
if use_bn:
|
||||
layers.append(nn.BatchNorm1d(hidden_dim))
|
||||
layers.append(nn.GELU())
|
||||
for _ in range(nlayers - 2):
|
||||
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
|
||||
if use_bn:
|
||||
layers.append(nn.BatchNorm1d(hidden_dim))
|
||||
layers.append(nn.GELU())
|
||||
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
|
||||
return nn.Sequential(*layers)
|
34
svrm/ldm/modules/encoders/dinov2/layers/drop_path.py
Normal file
@ -0,0 +1,34 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
||||
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
||||
if drop_prob == 0.0 or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||
if keep_prob > 0.0:
|
||||
random_tensor.div_(keep_prob)
|
||||
output = x * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
27
svrm/ldm/modules/encoders/dinov2/layers/layer_scale.py
Normal file
@ -0,0 +1,27 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LayerScale(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
init_values: Union[float, Tensor] = 1e-5,
|
||||
inplace: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inplace = inplace
|
||||
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
40
svrm/ldm/modules/encoders/dinov2/layers/mlp.py
Normal file
@ -0,0 +1,40 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
||||
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: Optional[int] = None,
|
||||
out_features: Optional[int] = None,
|
||||
act_layer: Callable[..., nn.Module] = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
88
svrm/ldm/modules/encoders/dinov2/layers/patch_embed.py
Normal file
@ -0,0 +1,88 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
||||
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
from torch import Tensor
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def make_2tuple(x):
|
||||
if isinstance(x, tuple):
|
||||
assert len(x) == 2
|
||||
return x
|
||||
|
||||
assert isinstance(x, int)
|
||||
return (x, x)
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
||||
|
||||
Args:
|
||||
img_size: Image size.
|
||||
patch_size: Patch token size.
|
||||
in_chans: Number of input image channels.
|
||||
embed_dim: Number of linear projection output channels.
|
||||
norm_layer: Normalization layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: Union[int, Tuple[int, int]] = 224,
|
||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer: Optional[Callable] = None,
|
||||
flatten_embedding: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
image_HW = make_2tuple(img_size)
|
||||
patch_HW = make_2tuple(patch_size)
|
||||
patch_grid_size = (
|
||||
image_HW[0] // patch_HW[0],
|
||||
image_HW[1] // patch_HW[1],
|
||||
)
|
||||
|
||||
self.img_size = image_HW
|
||||
self.patch_size = patch_HW
|
||||
self.patches_resolution = patch_grid_size
|
||||
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
||||
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.flatten_embedding = flatten_embedding
|
||||
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
_, _, H, W = x.shape
|
||||
patch_H, patch_W = self.patch_size
|
||||
|
||||
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
|
||||
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
||||
|
||||
x = self.proj(x) # B C H W
|
||||
H, W = x.size(2), x.size(3)
|
||||
x = x.flatten(2).transpose(1, 2) # B HW C
|
||||
x = self.norm(x)
|
||||
if not self.flatten_embedding:
|
||||
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
||||
return x
|
||||
|
||||
def flops(self) -> float:
|
||||
Ho, Wo = self.patches_resolution
|
||||
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
||||
if self.norm is not None:
|
||||
flops += Ho * Wo * self.embed_dim
|
||||
return flops
|
72
svrm/ldm/modules/encoders/dinov2/layers/swiglu_ffn.py
Normal file
@ -0,0 +1,72 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional
|
||||
import warnings
|
||||
|
||||
from torch import Tensor, nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class SwiGLUFFN(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: Optional[int] = None,
|
||||
out_features: Optional[int] = None,
|
||||
act_layer: Callable[..., nn.Module] = None,
|
||||
drop: float = 0.0,
|
||||
bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
||||
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x12 = self.w12(x)
|
||||
x1, x2 = x12.chunk(2, dim=-1)
|
||||
hidden = F.silu(x1) * x2
|
||||
return self.w3(hidden)
|
||||
|
||||
|
||||
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
||||
try:
|
||||
if XFORMERS_ENABLED:
|
||||
from xformers.ops import SwiGLU
|
||||
|
||||
XFORMERS_AVAILABLE = True
|
||||
warnings.warn("xFormers is available (SwiGLU)")
|
||||
else:
|
||||
warnings.warn("xFormers is disabled (SwiGLU)")
|
||||
raise ImportError
|
||||
except ImportError:
|
||||
SwiGLU = SwiGLUFFN
|
||||
XFORMERS_AVAILABLE = False
|
||||
|
||||
warnings.warn("xFormers is not available (SwiGLU)")
|
||||
|
||||
|
||||
class SwiGLUFFNFused(SwiGLU):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: Optional[int] = None,
|
||||
out_features: Optional[int] = None,
|
||||
act_layer: Callable[..., nn.Module] = None,
|
||||
drop: float = 0.0,
|
||||
bias: bool = True,
|
||||
) -> None:
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
||||
super().__init__(
|
||||
in_features=in_features,
|
||||
hidden_features=hidden_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
)
|
43
svrm/ldm/modules/encoders/dinov2/models/__init__.py
Normal file
@ -0,0 +1,43 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
from . import vision_transformer as vits
|
||||
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def build_model(args, only_teacher=False, img_size=224):
|
||||
args.arch = args.arch.removesuffix("_memeff")
|
||||
if "vit" in args.arch:
|
||||
vit_kwargs = dict(
|
||||
img_size=img_size,
|
||||
patch_size=args.patch_size,
|
||||
init_values=args.layerscale,
|
||||
ffn_layer=args.ffn_layer,
|
||||
block_chunks=args.block_chunks,
|
||||
qkv_bias=args.qkv_bias,
|
||||
proj_bias=args.proj_bias,
|
||||
ffn_bias=args.ffn_bias,
|
||||
num_register_tokens=args.num_register_tokens,
|
||||
interpolate_offset=args.interpolate_offset,
|
||||
interpolate_antialias=args.interpolate_antialias,
|
||||
)
|
||||
teacher = vits.__dict__[args.arch](**vit_kwargs)
|
||||
if only_teacher:
|
||||
return teacher, teacher.embed_dim
|
||||
student = vits.__dict__[args.arch](
|
||||
**vit_kwargs,
|
||||
drop_path_rate=args.drop_path_rate,
|
||||
drop_path_uniform=args.drop_path_uniform,
|
||||
)
|
||||
embed_dim = student.embed_dim
|
||||
return student, teacher, embed_dim
|
||||
|
||||
|
||||
def build_model_from_cfg(cfg, only_teacher=False):
|
||||
return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
|
415
svrm/ldm/modules/encoders/dinov2/models/vision_transformer.py
Normal file
@ -0,0 +1,415 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
||||
|
||||
from functools import partial
|
||||
import math
|
||||
import logging
|
||||
from typing import Sequence, Tuple, Union, Callable
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn.init import trunc_normal_
|
||||
|
||||
from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlockMod as BlockMod
|
||||
from ....attention import AdaNorm
|
||||
|
||||
logger = logging.getLogger("dinov2")
|
||||
|
||||
|
||||
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
||||
if not depth_first and include_root:
|
||||
fn(module=module, name=name)
|
||||
for child_name, child_module in module.named_children():
|
||||
child_name = ".".join((name, child_name)) if name else child_name
|
||||
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
||||
if depth_first and include_root:
|
||||
fn(module=module, name=name)
|
||||
return module
|
||||
|
||||
|
||||
class BlockChunk(nn.ModuleList):
|
||||
def forward(self, x):
|
||||
for b in self:
|
||||
x = b(x)
|
||||
return x
|
||||
|
||||
|
||||
class DinoVisionTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=True,
|
||||
ffn_bias=True,
|
||||
proj_bias=True,
|
||||
drop_path_rate=0.0,
|
||||
drop_path_uniform=False,
|
||||
init_values=None, # for layerscale: None or 0 => no layerscale
|
||||
embed_layer=PatchEmbed,
|
||||
act_layer=nn.GELU,
|
||||
block_fn=BlockMod,
|
||||
ffn_layer="mlp",
|
||||
block_chunks=1,
|
||||
num_register_tokens=0,
|
||||
interpolate_antialias=False,
|
||||
interpolate_offset=0.1,
|
||||
pos_emb_dim=768,
|
||||
cam_cond_dim=20
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
img_size (int, tuple): input image size
|
||||
patch_size (int, tuple): patch size
|
||||
in_chans (int): number of input channels
|
||||
embed_dim (int): embedding dimension
|
||||
depth (int): depth of transformer
|
||||
num_heads (int): number of attention heads
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
||||
qkv_bias (bool): enable bias for qkv if True
|
||||
proj_bias (bool): enable bias for proj in attn if True
|
||||
ffn_bias (bool): enable bias for ffn if True
|
||||
drop_path_rate (float): stochastic depth rate
|
||||
drop_path_uniform (bool): apply uniform drop rate across blocks
|
||||
weight_init (str): weight init scheme
|
||||
init_values (float): layer-scale init values
|
||||
embed_layer (nn.Module): patch embedding layer
|
||||
act_layer (nn.Module): MLP activation layer
|
||||
block_fn (nn.Module): transformer block class
|
||||
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
||||
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
||||
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
||||
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
|
||||
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
||||
"""
|
||||
super().__init__()
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
self.num_tokens = 1
|
||||
self.n_blocks = depth
|
||||
self.num_heads = num_heads
|
||||
self.patch_size = patch_size
|
||||
self.num_register_tokens = num_register_tokens
|
||||
self.interpolate_antialias = interpolate_antialias
|
||||
self.interpolate_offset = interpolate_offset
|
||||
|
||||
|
||||
norm_layer = AdaNorm
|
||||
self.cam_embed = nn.Sequential(
|
||||
nn.Linear(cam_cond_dim, pos_emb_dim, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(pos_emb_dim, pos_emb_dim, bias=True))
|
||||
|
||||
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
||||
assert num_register_tokens >= 0
|
||||
self.register_tokens = (
|
||||
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
||||
)
|
||||
|
||||
if drop_path_uniform is True:
|
||||
dpr = [drop_path_rate] * depth
|
||||
else:
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
|
||||
if ffn_layer == "mlp":
|
||||
logger.info("using MLP layer as FFN")
|
||||
ffn_layer = Mlp
|
||||
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
||||
logger.info("using SwiGLU layer as FFN")
|
||||
ffn_layer = SwiGLUFFNFused
|
||||
elif ffn_layer == "identity":
|
||||
logger.info("using Identity layer as FFN")
|
||||
|
||||
def f(*args, **kwargs):
|
||||
return nn.Identity()
|
||||
|
||||
ffn_layer = f
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
blocks_list = [
|
||||
block_fn(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
proj_bias=proj_bias,
|
||||
ffn_bias=ffn_bias,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
ffn_layer=ffn_layer,
|
||||
init_values=init_values,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
if block_chunks > 0:
|
||||
self.chunked_blocks = True
|
||||
chunked_blocks = []
|
||||
chunksize = depth // block_chunks
|
||||
for i in range(0, depth, chunksize):
|
||||
# this is to keep the block index consistent if we chunk the block list
|
||||
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
||||
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
||||
else:
|
||||
self.chunked_blocks = False
|
||||
self.blocks = nn.ModuleList(blocks_list)
|
||||
|
||||
self.norm = norm_layer(embed_dim)
|
||||
self.head = nn.Identity()
|
||||
|
||||
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
nn.init.normal_(self.cls_token, std=1e-6)
|
||||
if self.register_tokens is not None:
|
||||
nn.init.normal_(self.register_tokens, std=1e-6)
|
||||
named_apply(init_weights_vit_timm, self)
|
||||
|
||||
def interpolate_pos_encoding(self, x, w, h):
|
||||
previous_dtype = x.dtype
|
||||
npatch = x.shape[1] - 1
|
||||
N = self.pos_embed.shape[1] - 1
|
||||
if npatch == N and w == h:
|
||||
return self.pos_embed
|
||||
pos_embed = self.pos_embed.float()
|
||||
class_pos_embed = pos_embed[:, 0]
|
||||
patch_pos_embed = pos_embed[:, 1:]
|
||||
dim = x.shape[-1]
|
||||
w0 = w // self.patch_size
|
||||
h0 = h // self.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
|
||||
|
||||
sqrt_N = math.sqrt(N)
|
||||
sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
|
||||
scale_factor=(sx, sy),
|
||||
mode="bicubic",
|
||||
antialias=self.interpolate_antialias,
|
||||
)
|
||||
|
||||
assert int(w0) == patch_pos_embed.shape[-2]
|
||||
assert int(h0) == patch_pos_embed.shape[-1]
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
||||
|
||||
def prepare_tokens_with_masks(self, x, masks=None):
|
||||
B, nc, w, h = x.shape
|
||||
x = self.patch_embed(x)
|
||||
if masks is not None:
|
||||
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
||||
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + self.interpolate_pos_encoding(x, w, h)
|
||||
|
||||
if self.register_tokens is not None:
|
||||
x = torch.cat(
|
||||
(
|
||||
x[:, :1],
|
||||
self.register_tokens.expand(x.shape[0], -1, -1),
|
||||
x[:, 1:],
|
||||
),
|
||||
dim=1,
|
||||
)
|
||||
|
||||
return x
|
||||
|
||||
def forward_features_list(self, x_list, masks_list):
|
||||
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
all_x = x
|
||||
output = []
|
||||
for x, masks in zip(all_x, masks_list):
|
||||
x_norm = self.norm(x)
|
||||
output.append(
|
||||
{
|
||||
"x_norm_clstoken": x_norm[:, 0],
|
||||
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
||||
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
||||
"x_prenorm": x,
|
||||
"masks": masks,
|
||||
}
|
||||
)
|
||||
return output
|
||||
|
||||
def forward_features_list_with_camera(self, x_list, cam_cond_list, masks_list):
|
||||
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
||||
cam_emb = [self.cam_embed(cam_cond) for cam_cond in cam_cond_list]
|
||||
for blk in self.blocks:
|
||||
x = blk(x, cam_emb)
|
||||
|
||||
all_x = x
|
||||
all_cam_emb = cam_emb
|
||||
output = []
|
||||
for x, cam_emb, masks in zip(all_x, all_cam_emb, masks_list):
|
||||
x_norm = self.norm(x, cam_emb)
|
||||
output.append(
|
||||
{
|
||||
"x_norm_clstoken": x_norm[:, 0],
|
||||
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
||||
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
||||
"x_prenorm": x,
|
||||
"masks": masks,
|
||||
}
|
||||
)
|
||||
return output
|
||||
|
||||
def forward_features(self, x, masks=None):
|
||||
if isinstance(x, list):
|
||||
return self.forward_features_list(x, masks)
|
||||
|
||||
x = self.prepare_tokens_with_masks(x, masks)
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
x_norm = self.norm(x)
|
||||
return {
|
||||
"x_norm_clstoken": x_norm[:, 0],
|
||||
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
||||
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
||||
"x_prenorm": x,
|
||||
"masks": masks,
|
||||
}
|
||||
|
||||
def forward_features_with_camera(self, x, cam_cond, masks=None):
|
||||
if isinstance(x, list):
|
||||
return self.forward_features_list(x, cam_cond, masks)
|
||||
cam_emb = self.cam_embed(cam_cond)
|
||||
x = self.prepare_tokens_with_masks(x, masks)
|
||||
for blk in self.blocks:
|
||||
x = blk(x, cam_emb)
|
||||
x_norm = self.norm(x, cam_emb)
|
||||
return {
|
||||
"x_norm_clstoken": x_norm[:, 0],
|
||||
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
||||
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
||||
"x_prenorm": x,
|
||||
"masks": masks,
|
||||
}
|
||||
|
||||
def _get_inter_layers_not_chunked(self, x, n=1):
|
||||
x = self.prepare_tokens_with_masks(x)
|
||||
# If n is an int, take the n last blocks. If it's a list, take them
|
||||
output, total_block_len = [], len(self.blocks)
|
||||
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
||||
for i, blk in enumerate(self.blocks):
|
||||
x = blk(x)
|
||||
if i in blocks_to_take:
|
||||
output.append(x)
|
||||
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
||||
return output
|
||||
|
||||
def _get_intermediate_layers_chunked(self, x, n=1):
|
||||
x = self.prepare_tokens_with_masks(x)
|
||||
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
||||
# If n is an int, take the n last blocks. If it's a list, take them
|
||||
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
||||
for block_chunk in self.blocks:
|
||||
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
||||
x = blk(x)
|
||||
if i in blocks_to_take:
|
||||
output.append(x)
|
||||
i += 1
|
||||
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
||||
return output
|
||||
|
||||
def get_intermediate_layers(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
||||
reshape: bool = False,
|
||||
return_class_token: bool = False,
|
||||
norm=True,
|
||||
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
||||
if self.chunked_blocks:
|
||||
outputs = self._get_intermediate_layers_chunked(x, n)
|
||||
else:
|
||||
outputs = self._get_inter_layers_not_chunked(x, n)
|
||||
if norm:
|
||||
outputs = [self.norm(out) for out in outputs]
|
||||
class_tokens = [out[:, 0] for out in outputs]
|
||||
outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
|
||||
if reshape:
|
||||
B, _, w, h = x.shape
|
||||
outputs = [
|
||||
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
||||
for out in outputs
|
||||
]
|
||||
if return_class_token:
|
||||
return tuple(zip(outputs, class_tokens))
|
||||
return tuple(outputs)
|
||||
|
||||
def forward(self, *args, is_training=False, **kwargs):
|
||||
|
||||
ret = self.forward_features_with_camera(*args, **kwargs)
|
||||
|
||||
if is_training:
|
||||
return ret
|
||||
else:
|
||||
return self.head(ret["x_norm_clstoken"])
|
||||
|
||||
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
||||
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
||||
if isinstance(module, nn.Linear):
|
||||
trunc_normal_(module.weight, std=0.02)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, AdaNorm):
|
||||
nn.init.constant_(module.adaLN_modulation[-1].weight, 0)
|
||||
nn.init.constant_(module.adaLN_modulation[-1].bias, 0)
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
if module.weight is not None:
|
||||
nn.init.constant_(module.weight, 1.0)
|
||||
|
||||
|
||||
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
|
||||
model = DinoVisionTransformer(
|
||||
patch_size=patch_size,
|
||||
embed_dim=384,
|
||||
depth=12,
|
||||
num_heads=6,
|
||||
mlp_ratio=4,
|
||||
block_fn=partial(BlockMod, attn_class=MemEffAttention),
|
||||
num_register_tokens=num_register_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
|
||||
model = DinoVisionTransformer(
|
||||
patch_size=patch_size,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
block_fn=partial(BlockMod, attn_class=MemEffAttention),
|
||||
num_register_tokens=num_register_tokens,
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
67
svrm/ldm/modules/encoders/dinov2_mod.py
Normal file
@ -0,0 +1,67 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
#
|
||||
# This source code is licensed under the Apache License, Version 2.0
|
||||
# found in the LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
||||
|
||||
from functools import partial
|
||||
import math
|
||||
import logging
|
||||
from typing import Sequence, Tuple, Union, Callable
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn.init import trunc_normal_
|
||||
|
||||
from .dinov2.hub.backbones import dinov2_vitb14
|
||||
|
||||
class FrozenDinoV2ImageEmbedder(nn.Module):
|
||||
"""
|
||||
Uses the dinov2 image encoder with camera modulation.
|
||||
Not actually frozen... If you want that set cond_stage_trainable=False in cfg
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
version='dinov2_vitb14',
|
||||
ckpt_path=None,
|
||||
lrm_mode='plain_lrm',
|
||||
):
|
||||
super().__init__()
|
||||
self.lrm_mode = lrm_mode
|
||||
assert version in ['dinov2_vitb14', 'dinov2_vits14', 'dinov2_vitl14', 'dinov2_vitg14']
|
||||
|
||||
|
||||
self.model = dinov2_vitb14(pretrained=False)
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.load_pretrained(ckpt_path)
|
||||
else:
|
||||
print('None pretrained model for dinov2 encoder ...')
|
||||
|
||||
|
||||
def load_pretrained(self, ckpt_path):
|
||||
print('Loading dinov2 encoder ...')
|
||||
orig_state_dict = torch.load(ckpt_path, map_location='cpu')
|
||||
try:
|
||||
ret = self.model.load_state_dict(orig_state_dict, strict=False)
|
||||
print(ret)
|
||||
print('Successfully loaded orig state dict')
|
||||
except:
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in orig_state_dict['state_dict'].items():
|
||||
if 'img_encoder' in k:
|
||||
new_state_dict[k.replace('img_encoder.model.', '')] = v
|
||||
ret = self.model.load_state_dict(new_state_dict, strict=False)
|
||||
print(ret)
|
||||
print('Successfully loaded new state dict')
|
||||
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
ret = self.model.forward_features_with_camera(x, *args, **kwargs)
|
||||
output = torch.cat([ret['x_norm_clstoken'].unsqueeze(1), ret['x_norm_patchtokens']], dim=1)
|
||||
return output
|
15
svrm/ldm/modules/rendering_neus/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
# Copyright (c) 2023, Zexin He
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Empty
|
311
svrm/ldm/modules/rendering_neus/mesh.py
Normal file
@ -0,0 +1,311 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...utils.typing import *
|
||||
|
||||
def dot(x, y):
|
||||
return torch.sum(x * y, -1, keepdim=True)
|
||||
|
||||
class Mesh:
|
||||
def __init__(
|
||||
self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], v_rgb: Integer[Tensor, "Nf 3"], **kwargs
|
||||
) -> None:
|
||||
self.v_pos: Float[Tensor, "Nv 3"] = v_pos
|
||||
self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
|
||||
self.v_rgb: Optional[Float[Tensor, "Nv 3"]] = v_rgb
|
||||
|
||||
self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
|
||||
self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
|
||||
self._v_tex: Optional[Float[Tensor, "Nt 3"]] = None
|
||||
self._t_tex_idx: Optional[Float[Tensor, "Nf 3"]] = None
|
||||
# self._v_rgb: Optional[Float[Tensor, "Nv 3"]] = None
|
||||
self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
|
||||
self.extras: Dict[str, Any] = {}
|
||||
for k, v in kwargs.items():
|
||||
self.add_extra(k, v)
|
||||
|
||||
def add_extra(self, k, v) -> None:
|
||||
self.extras[k] = v
|
||||
|
||||
def remove_outlier(self, outlier_n_faces_threshold: Union[int, float]) -> Mesh:
|
||||
if self.requires_grad:
|
||||
print("Mesh is differentiable, not removing outliers")
|
||||
return self
|
||||
|
||||
# use trimesh to first split the mesh into connected components
|
||||
# then remove the components with less than n_face_threshold faces
|
||||
import trimesh
|
||||
|
||||
# construct a trimesh object
|
||||
mesh = trimesh.Trimesh(
|
||||
vertices=self.v_pos.detach().cpu().numpy(),
|
||||
faces=self.t_pos_idx.detach().cpu().numpy(),
|
||||
)
|
||||
|
||||
# split the mesh into connected components
|
||||
components = mesh.split(only_watertight=False)
|
||||
# log the number of faces in each component
|
||||
print(
|
||||
"Mesh has {} components, with faces: {}".format(
|
||||
len(components), [c.faces.shape[0] for c in components]
|
||||
)
|
||||
)
|
||||
|
||||
n_faces_threshold: int
|
||||
if isinstance(outlier_n_faces_threshold, float):
|
||||
# set the threshold to the number of faces in the largest component multiplied by outlier_n_faces_threshold
|
||||
n_faces_threshold = int(
|
||||
max([c.faces.shape[0] for c in components]) * outlier_n_faces_threshold
|
||||
)
|
||||
else:
|
||||
# set the threshold directly to outlier_n_faces_threshold
|
||||
n_faces_threshold = outlier_n_faces_threshold
|
||||
|
||||
# log the threshold
|
||||
print(
|
||||
"Removing components with less than {} faces".format(n_faces_threshold)
|
||||
)
|
||||
|
||||
# remove the components with less than n_face_threshold faces
|
||||
components = [c for c in components if c.faces.shape[0] >= n_faces_threshold]
|
||||
|
||||
# log the number of faces in each component after removing outliers
|
||||
print(
|
||||
"Mesh has {} components after removing outliers, with faces: {}".format(
|
||||
len(components), [c.faces.shape[0] for c in components]
|
||||
)
|
||||
)
|
||||
# merge the components
|
||||
mesh = trimesh.util.concatenate(components)
|
||||
|
||||
# convert back to our mesh format
|
||||
v_pos = torch.from_numpy(mesh.vertices).to(self.v_pos)
|
||||
t_pos_idx = torch.from_numpy(mesh.faces).to(self.t_pos_idx)
|
||||
|
||||
clean_mesh = Mesh(v_pos, t_pos_idx)
|
||||
# keep the extras unchanged
|
||||
|
||||
if len(self.extras) > 0:
|
||||
clean_mesh.extras = self.extras
|
||||
print(
|
||||
f"The following extra attributes are inherited from the original mesh unchanged: {list(self.extras.keys())}"
|
||||
)
|
||||
return clean_mesh
|
||||
|
||||
@property
|
||||
def requires_grad(self):
|
||||
return self.v_pos.requires_grad
|
||||
|
||||
@property
|
||||
def v_nrm(self):
|
||||
if self._v_nrm is None:
|
||||
self._v_nrm = self._compute_vertex_normal()
|
||||
return self._v_nrm
|
||||
|
||||
@property
|
||||
def v_tng(self):
|
||||
if self._v_tng is None:
|
||||
self._v_tng = self._compute_vertex_tangent()
|
||||
return self._v_tng
|
||||
|
||||
@property
|
||||
def v_tex(self):
|
||||
if self._v_tex is None:
|
||||
self._v_tex, self._t_tex_idx = self._unwrap_uv()
|
||||
return self._v_tex
|
||||
|
||||
@property
|
||||
def t_tex_idx(self):
|
||||
if self._t_tex_idx is None:
|
||||
self._v_tex, self._t_tex_idx = self._unwrap_uv()
|
||||
return self._t_tex_idx
|
||||
|
||||
# @property
|
||||
# def v_rgb(self):
|
||||
# return self._v_rgb
|
||||
|
||||
@property
|
||||
def edges(self):
|
||||
if self._edges is None:
|
||||
self._edges = self._compute_edges()
|
||||
return self._edges
|
||||
|
||||
def _compute_vertex_normal(self):
|
||||
i0 = self.t_pos_idx[:, 0]
|
||||
i1 = self.t_pos_idx[:, 1]
|
||||
i2 = self.t_pos_idx[:, 2]
|
||||
|
||||
v0 = self.v_pos[i0, :]
|
||||
v1 = self.v_pos[i1, :]
|
||||
v2 = self.v_pos[i2, :]
|
||||
|
||||
face_normals = torch.cross(v1 - v0, v2 - v0)
|
||||
|
||||
# Splat face normals to vertices
|
||||
v_nrm = torch.zeros_like(self.v_pos)
|
||||
v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals)
|
||||
v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals)
|
||||
v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals)
|
||||
|
||||
# Normalize, replace zero (degenerated) normals with some default value
|
||||
v_nrm = torch.where(
|
||||
dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
|
||||
)
|
||||
v_nrm = F.normalize(v_nrm, dim=1)
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(v_nrm))
|
||||
|
||||
return v_nrm
|
||||
|
||||
def _compute_vertex_tangent(self):
|
||||
vn_idx = [None] * 3
|
||||
pos = [None] * 3
|
||||
tex = [None] * 3
|
||||
for i in range(0, 3):
|
||||
pos[i] = self.v_pos[self.t_pos_idx[:, i]]
|
||||
tex[i] = self.v_tex[self.t_tex_idx[:, i]]
|
||||
# t_nrm_idx is always the same as t_pos_idx
|
||||
vn_idx[i] = self.t_pos_idx[:, i]
|
||||
|
||||
tangents = torch.zeros_like(self.v_nrm)
|
||||
tansum = torch.zeros_like(self.v_nrm)
|
||||
|
||||
# Compute tangent space for each triangle
|
||||
uve1 = tex[1] - tex[0]
|
||||
uve2 = tex[2] - tex[0]
|
||||
pe1 = pos[1] - pos[0]
|
||||
pe2 = pos[2] - pos[0]
|
||||
|
||||
nom = pe1 * uve2[..., 1:2] - pe2 * uve1[..., 1:2]
|
||||
denom = uve1[..., 0:1] * uve2[..., 1:2] - uve1[..., 1:2] * uve2[..., 0:1]
|
||||
|
||||
# Avoid division by zero for degenerated texture coordinates
|
||||
tang = nom / torch.where(
|
||||
denom > 0.0, torch.clamp(denom, min=1e-6), torch.clamp(denom, max=-1e-6)
|
||||
)
|
||||
|
||||
# Update all 3 vertices
|
||||
for i in range(0, 3):
|
||||
idx = vn_idx[i][:, None].repeat(1, 3)
|
||||
tangents.scatter_add_(0, idx, tang) # tangents[n_i] = tangents[n_i] + tang
|
||||
tansum.scatter_add_(
|
||||
0, idx, torch.ones_like(tang)
|
||||
) # tansum[n_i] = tansum[n_i] + 1
|
||||
tangents = tangents / tansum
|
||||
|
||||
# Normalize and make sure tangent is perpendicular to normal
|
||||
tangents = F.normalize(tangents, dim=1)
|
||||
tangents = F.normalize(tangents - dot(tangents, self.v_nrm) * self.v_nrm)
|
||||
|
||||
if torch.is_anomaly_enabled():
|
||||
assert torch.all(torch.isfinite(tangents))
|
||||
|
||||
return tangents
|
||||
|
||||
def _unwrap_uv(
|
||||
self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {}
|
||||
):
|
||||
print("Using xatlas to perform UV unwrapping, may take a while ...")
|
||||
|
||||
import xatlas
|
||||
|
||||
atlas = xatlas.Atlas()
|
||||
atlas.add_mesh(
|
||||
self.v_pos.detach().cpu().numpy(),
|
||||
self.t_pos_idx.cpu().numpy(),
|
||||
)
|
||||
co = xatlas.ChartOptions()
|
||||
po = xatlas.PackOptions()
|
||||
for k, v in xatlas_chart_options.items():
|
||||
setattr(co, k, v)
|
||||
for k, v in xatlas_pack_options.items():
|
||||
setattr(po, k, v)
|
||||
atlas.generate(co, po)
|
||||
vmapping, indices, uvs = atlas.get_mesh(0)
|
||||
vmapping = (
|
||||
torch.from_numpy(
|
||||
vmapping.astype(np.uint64, casting="same_kind").view(np.int64)
|
||||
)
|
||||
.to(self.v_pos.device)
|
||||
.long()
|
||||
)
|
||||
uvs = torch.from_numpy(uvs).to(self.v_pos.device).float()
|
||||
indices = (
|
||||
torch.from_numpy(
|
||||
indices.astype(np.uint64, casting="same_kind").view(np.int64)
|
||||
)
|
||||
.to(self.v_pos.device)
|
||||
.long()
|
||||
)
|
||||
return uvs, indices
|
||||
|
||||
def unwrap_uv(
|
||||
self, xatlas_chart_options: dict = {}, xatlas_pack_options: dict = {}
|
||||
):
|
||||
self._v_tex, self._t_tex_idx = self._unwrap_uv(
|
||||
xatlas_chart_options, xatlas_pack_options
|
||||
)
|
||||
|
||||
def set_vertex_color(self, v_rgb):
|
||||
assert v_rgb.shape[0] == self.v_pos.shape[0]
|
||||
self._v_rgb = v_rgb
|
||||
|
||||
def _compute_edges(self):
|
||||
# Compute edges
|
||||
edges = torch.cat(
|
||||
[
|
||||
self.t_pos_idx[:, [0, 1]],
|
||||
self.t_pos_idx[:, [1, 2]],
|
||||
self.t_pos_idx[:, [2, 0]],
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
edges = edges.sort()[0]
|
||||
edges = torch.unique(edges, dim=0)
|
||||
return edges
|
||||
|
||||
def normal_consistency(self) -> Float[Tensor, ""]:
|
||||
edge_nrm: Float[Tensor, "Ne 2 3"] = self.v_nrm[self.edges]
|
||||
nc = (
|
||||
1.0 - torch.cosine_similarity(edge_nrm[:, 0], edge_nrm[:, 1], dim=-1)
|
||||
).mean()
|
||||
return nc
|
||||
|
||||
def _laplacian_uniform(self):
|
||||
# from stable-dreamfusion
|
||||
# https://github.com/ashawkey/stable-dreamfusion/blob/8fb3613e9e4cd1ded1066b46e80ca801dfb9fd06/nerf/renderer.py#L224
|
||||
verts, faces = self.v_pos, self.t_pos_idx
|
||||
|
||||
V = verts.shape[0]
|
||||
F = faces.shape[0]
|
||||
|
||||
# Neighbor indices
|
||||
ii = faces[:, [1, 2, 0]].flatten()
|
||||
jj = faces[:, [2, 0, 1]].flatten()
|
||||
adj = torch.stack([torch.cat([ii, jj]), torch.cat([jj, ii])], dim=0).unique(
|
||||
dim=1
|
||||
)
|
||||
adj_values = torch.ones(adj.shape[1]).to(verts)
|
||||
|
||||
# Diagonal indices
|
||||
diag_idx = adj[0]
|
||||
|
||||
# Build the sparse matrix
|
||||
idx = torch.cat((adj, torch.stack((diag_idx, diag_idx), dim=0)), dim=1)
|
||||
values = torch.cat((-adj_values, adj_values))
|
||||
|
||||
# The coalesce operation sums the duplicate indices, resulting in the
|
||||
# correct diagonal
|
||||
return torch.sparse_coo_tensor(idx, values, (V, V)).coalesce()
|
||||
|
||||
def laplacian(self) -> Float[Tensor, ""]:
|
||||
with torch.no_grad():
|
||||
L = self._laplacian_uniform()
|
||||
loss = L.mm(self.v_pos)
|
||||
loss = loss.norm(dim=1)
|
||||
loss = loss.mean()
|
||||
return loss
|
78
svrm/ldm/modules/rendering_neus/rasterize.py
Normal file
@ -0,0 +1,78 @@
|
||||
import nvdiffrast.torch as dr
|
||||
import torch
|
||||
|
||||
from ...utils.typing import *
|
||||
|
||||
|
||||
class NVDiffRasterizerContext:
|
||||
def __init__(self, context_type: str, device: torch.device) -> None:
|
||||
self.device = device
|
||||
self.ctx = self.initialize_context(context_type, device)
|
||||
|
||||
def initialize_context(
|
||||
self, context_type: str, device: torch.device
|
||||
) -> Union[dr.RasterizeGLContext, dr.RasterizeCudaContext]:
|
||||
if context_type == "gl":
|
||||
return dr.RasterizeGLContext(device=device)
|
||||
elif context_type == "cuda":
|
||||
return dr.RasterizeCudaContext(device=device)
|
||||
else:
|
||||
raise ValueError(f"Unknown rasterizer context type: {context_type}")
|
||||
|
||||
def vertex_transform(
|
||||
self, verts: Float[Tensor, "Nv 3"], mvp_mtx: Float[Tensor, "B 4 4"]
|
||||
) -> Float[Tensor, "B Nv 4"]:
|
||||
verts_homo = torch.cat(
|
||||
[verts, torch.ones([verts.shape[0], 1]).to(verts)], dim=-1
|
||||
)
|
||||
return torch.matmul(verts_homo, mvp_mtx.permute(0, 2, 1))
|
||||
|
||||
def rasterize(
|
||||
self,
|
||||
pos: Float[Tensor, "B Nv 4"],
|
||||
tri: Integer[Tensor, "Nf 3"],
|
||||
resolution: Union[int, Tuple[int, int]],
|
||||
):
|
||||
# rasterize in instance mode (single topology)
|
||||
return dr.rasterize(self.ctx, pos.float(), tri.int(), resolution, grad_db=True)
|
||||
|
||||
def rasterize_one(
|
||||
self,
|
||||
pos: Float[Tensor, "Nv 4"],
|
||||
tri: Integer[Tensor, "Nf 3"],
|
||||
resolution: Union[int, Tuple[int, int]],
|
||||
):
|
||||
# rasterize one single mesh under a single viewpoint
|
||||
rast, rast_db = self.rasterize(pos[None, ...], tri, resolution)
|
||||
return rast[0], rast_db[0]
|
||||
|
||||
def antialias(
|
||||
self,
|
||||
color: Float[Tensor, "B H W C"],
|
||||
rast: Float[Tensor, "B H W 4"],
|
||||
pos: Float[Tensor, "B Nv 4"],
|
||||
tri: Integer[Tensor, "Nf 3"],
|
||||
) -> Float[Tensor, "B H W C"]:
|
||||
return dr.antialias(color.float(), rast, pos.float(), tri.int())
|
||||
|
||||
def interpolate(
|
||||
self,
|
||||
attr: Float[Tensor, "B Nv C"],
|
||||
rast: Float[Tensor, "B H W 4"],
|
||||
tri: Integer[Tensor, "Nf 3"],
|
||||
rast_db=None,
|
||||
diff_attrs=None,
|
||||
) -> Float[Tensor, "B H W C"]:
|
||||
return dr.interpolate(
|
||||
attr.float(), rast, tri.int(), rast_db=rast_db, diff_attrs=diff_attrs
|
||||
)
|
||||
|
||||
def interpolate_one(
|
||||
self,
|
||||
attr: Float[Tensor, "Nv C"],
|
||||
rast: Float[Tensor, "B H W 4"],
|
||||
tri: Integer[Tensor, "Nf 3"],
|
||||
rast_db=None,
|
||||
diff_attrs=None,
|
||||
) -> Float[Tensor, "B H W C"]:
|
||||
return self.interpolate(attr[None, ...], rast, tri, rast_db, diff_attrs)
|
277
svrm/ldm/modules/rendering_neus/synthesizer.py
Normal file
@ -0,0 +1,277 @@
|
||||
# ORIGINAL LICENSE
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# Modified by Zexin He
|
||||
# The modifications are subject to the same license as the original.
|
||||
|
||||
|
||||
import itertools
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils.renderer import ImportanceRenderer, sample_from_planes
|
||||
from .utils.ray_sampler import RaySampler
|
||||
from ...utils.ops import get_rank
|
||||
|
||||
|
||||
class OSGDecoder(nn.Module):
|
||||
"""
|
||||
Triplane decoder that gives RGB and sigma values from sampled features.
|
||||
Using ReLU here instead of Softplus in the original implementation.
|
||||
|
||||
Reference:
|
||||
EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
|
||||
"""
|
||||
def __init__(self, n_features: int,
|
||||
hidden_dim: int = 64,
|
||||
num_layers: int = 2,
|
||||
activation: nn.Module = nn.ReLU,
|
||||
sdf_bias='sphere',
|
||||
sdf_bias_params=0.5,
|
||||
output_normal=True,
|
||||
normal_type='finite_difference'):
|
||||
super().__init__()
|
||||
self.sdf_bias = sdf_bias
|
||||
self.sdf_bias_params = sdf_bias_params
|
||||
self.output_normal = output_normal
|
||||
self.normal_type = normal_type
|
||||
self.net = nn.Sequential(
|
||||
nn.Linear(3 * n_features, hidden_dim),
|
||||
activation(),
|
||||
*itertools.chain(*[[
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
activation(),
|
||||
] for _ in range(num_layers - 2)]),
|
||||
nn.Linear(hidden_dim, 1 + 3),
|
||||
)
|
||||
# init all bias to zero
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def forward(self, ray_directions, sample_coordinates, plane_axes, planes, options):
|
||||
# Aggregate features by mean
|
||||
# sampled_features = sampled_features.mean(1)
|
||||
# Aggregate features by concatenation
|
||||
# torch.set_grad_enabled(True)
|
||||
# sample_coordinates.requires_grad_(True)
|
||||
|
||||
sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp'])
|
||||
|
||||
|
||||
_N, n_planes, _M, _C = sampled_features.shape
|
||||
sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
|
||||
x = sampled_features
|
||||
|
||||
N, M, C = x.shape
|
||||
# x = x.contiguous().view(N*M, C)
|
||||
|
||||
x = self.net(x)
|
||||
x = x.view(N, M, -1)
|
||||
rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
|
||||
|
||||
sdf = x[..., 0:1]
|
||||
# import ipdb; ipdb.set_trace()
|
||||
# print(f'sample_coordinates shape: {sample_coordinates.shape}')
|
||||
# sdf = self.get_shifted_sdf(sample_coordinates, sdf)
|
||||
|
||||
# calculate normal
|
||||
eps = 0.01
|
||||
offsets = torch.as_tensor(
|
||||
[[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]]
|
||||
).to(sample_coordinates)
|
||||
points_offset = (
|
||||
sample_coordinates[..., None, :] + offsets # Float[Tensor, "... 3 3"]
|
||||
).clamp(options['sampler_bbox_min'], options['sampler_bbox_max'])
|
||||
|
||||
sdf_offset_list = [self.forward_sdf(
|
||||
plane_axes,
|
||||
planes,
|
||||
points_offset[:,:,i,:],
|
||||
options
|
||||
).unsqueeze(-2) for i in range(points_offset.shape[-2])] # Float[Tensor, "... 3 1"]
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
sdf_offset = torch.cat(sdf_offset_list, -2)
|
||||
sdf_grad = (sdf_offset[..., 0::1, 0] - sdf) / eps
|
||||
|
||||
normal = F.normalize(sdf_grad, dim=-1).to(sdf.dtype)
|
||||
return {'rgb': rgb, 'sdf': sdf, 'normal': normal, 'sdf_grad': sdf_grad}
|
||||
|
||||
def forward_sdf(self, plane_axes, planes, points_offset, options):
|
||||
|
||||
sampled_features = sample_from_planes(plane_axes, planes, points_offset, padding_mode='zeros', box_warp=options['box_warp'])
|
||||
_N, n_planes, _M, _C = sampled_features.shape
|
||||
sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C)
|
||||
x = sampled_features
|
||||
|
||||
N, M, C = x.shape
|
||||
# x = x.contiguous().view(N*M, C)
|
||||
|
||||
x = self.net(x)
|
||||
x = x.view(N, M, -1)
|
||||
sdf = x[..., 0:1]
|
||||
# sdf = self.get_shifted_sdf(points_offset, sdf)
|
||||
return sdf
|
||||
|
||||
def get_shifted_sdf(
|
||||
self, points, sdf
|
||||
):
|
||||
if self.sdf_bias == "sphere":
|
||||
assert isinstance(self.sdf_bias_params, float)
|
||||
radius = self.sdf_bias_params
|
||||
sdf_bias = (points**2).sum(dim=-1, keepdim=True).sqrt() - radius
|
||||
else:
|
||||
raise ValueError(f"Unknown sdf bias {self.cfg.sdf_bias}")
|
||||
return sdf + sdf_bias.to(sdf.dtype)
|
||||
|
||||
|
||||
class TriplaneSynthesizer(nn.Module):
|
||||
"""
|
||||
Synthesizer that renders a triplane volume with planes and a camera.
|
||||
|
||||
Reference:
|
||||
EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19
|
||||
"""
|
||||
|
||||
DEFAULT_RENDERING_KWARGS = {
|
||||
'ray_start': 'auto',
|
||||
'ray_end': 'auto',
|
||||
'box_warp': 1.2,
|
||||
# 'box_warp': 1.,
|
||||
'white_back': True,
|
||||
'disparity_space_sampling': False,
|
||||
'clamp_mode': 'softplus',
|
||||
# 'sampler_bbox_min': -1,
|
||||
# 'sampler_bbox_max': 1.,
|
||||
'sampler_bbox_min': -0.6,
|
||||
'sampler_bbox_max': 0.6,
|
||||
}
|
||||
print('DEFAULT_RENDERING_KWARGS')
|
||||
print(DEFAULT_RENDERING_KWARGS)
|
||||
|
||||
|
||||
def __init__(self, triplane_dim: int, samples_per_ray: int, osg_decoder='default'):
|
||||
super().__init__()
|
||||
|
||||
# attributes
|
||||
self.triplane_dim = triplane_dim
|
||||
self.rendering_kwargs = {
|
||||
**self.DEFAULT_RENDERING_KWARGS,
|
||||
'depth_resolution': samples_per_ray,
|
||||
'depth_resolution_importance': 0
|
||||
# 'depth_resolution': samples_per_ray // 2,
|
||||
# 'depth_resolution_importance': samples_per_ray // 2,
|
||||
}
|
||||
|
||||
# renderings
|
||||
self.renderer = ImportanceRenderer()
|
||||
self.ray_sampler = RaySampler()
|
||||
# modules
|
||||
if osg_decoder == 'default':
|
||||
self.decoder = OSGDecoder(n_features=triplane_dim)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def forward(self, planes, ray_origins, ray_directions, render_size, bgcolor=None):
|
||||
# planes: (N, 3, D', H', W')
|
||||
# render_size: int
|
||||
assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional"
|
||||
|
||||
|
||||
# Perform volume rendering
|
||||
rgb_samples, depth_samples, weights_samples, sdf_grad, normal_samples = self.renderer(
|
||||
planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs, bgcolor
|
||||
)
|
||||
N = planes.shape[0]
|
||||
|
||||
# zhaohx : add for normals
|
||||
normal_samples = F.normalize(normal_samples, dim=-1)
|
||||
normal_samples = (normal_samples + 1.0) / 2.0 # for visualization
|
||||
normal_samples = torch.lerp(torch.zeros_like(normal_samples), normal_samples, weights_samples)
|
||||
|
||||
# Reshape into 'raw' neural-rendered image
|
||||
Himg = Wimg = render_size
|
||||
rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, rgb_samples.shape[-1], Himg, Wimg).contiguous()
|
||||
depth_images = depth_samples.permute(0, 2, 1).reshape(N, 1, Himg, Wimg)
|
||||
weight_images = weights_samples.permute(0, 2, 1).reshape(N, 1, Himg, Wimg)
|
||||
|
||||
# zhaohx : add for normals
|
||||
normal_images = normal_samples.permute(0, 2, 1).reshape(N, normal_samples.shape[-1], Himg, Wimg).contiguous()
|
||||
|
||||
# return {
|
||||
# 'images_rgb': rgb_images,
|
||||
# 'images_depth': depth_images,
|
||||
# 'images_weight': weight_images,
|
||||
# }
|
||||
|
||||
return {
|
||||
'comp_rgb': rgb_images,
|
||||
'comp_depth': depth_images,
|
||||
'opacity': weight_images,
|
||||
'sdf_grad': sdf_grad,
|
||||
'comp_normal': normal_images
|
||||
}
|
||||
# 输出normal的话在这个return里加
|
||||
|
||||
def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None):
|
||||
# planes: (N, 3, D', H', W')
|
||||
# grid_size: int
|
||||
# aabb: (N, 2, 3)
|
||||
if aabb is None:
|
||||
aabb = torch.tensor([
|
||||
[self.rendering_kwargs['sampler_bbox_min']] * 3,
|
||||
[self.rendering_kwargs['sampler_bbox_max']] * 3,
|
||||
], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)
|
||||
assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
|
||||
N = planes.shape[0]
|
||||
|
||||
# create grid points for triplane query
|
||||
grid_points = []
|
||||
for i in range(N):
|
||||
grid_points.append(torch.stack(torch.meshgrid(
|
||||
torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device),
|
||||
torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device),
|
||||
torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device),
|
||||
indexing='ij',
|
||||
), dim=-1).reshape(-1, 3))
|
||||
cube_grid = torch.stack(grid_points, dim=0).to(planes.device)
|
||||
|
||||
features = self.forward_points(planes, cube_grid)
|
||||
|
||||
# reshape into grid
|
||||
features = {
|
||||
k: v.reshape(N, grid_size, grid_size, grid_size, -1)
|
||||
for k, v in features.items()
|
||||
}
|
||||
return features
|
||||
|
||||
def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20):
|
||||
# planes: (N, 3, D', H', W')
|
||||
# points: (N, P, 3)
|
||||
N, P = points.shape[:2]
|
||||
|
||||
# query triplane in chunks
|
||||
outs = []
|
||||
for i in range(0, points.shape[1], chunk_size):
|
||||
chunk_points = points[:, i:i+chunk_size]
|
||||
|
||||
# query triplane
|
||||
# chunk_out = self.renderer.run_model_activated(
|
||||
chunk_out = self.renderer.run_model(
|
||||
planes=planes,
|
||||
decoder=self.decoder,
|
||||
sample_coordinates=chunk_points,
|
||||
sample_directions=torch.zeros_like(chunk_points),
|
||||
options=self.rendering_kwargs,
|
||||
)
|
||||
outs.append(chunk_out)
|
||||
|
||||
# concatenate the outputs
|
||||
point_features = {
|
||||
k: torch.cat([out[k] for out in outs], dim=1)
|
||||
for k in outs[0].keys()
|
||||
}
|
||||
return point_features
|
11
svrm/ldm/modules/rendering_neus/third_party/__init__.py
vendored
Normal file
@ -0,0 +1,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
# empty
|
159
svrm/ldm/modules/rendering_neus/third_party/custom_ops.py
vendored
Normal file
@ -0,0 +1,159 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import glob
|
||||
import hashlib
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import uuid
|
||||
|
||||
import torch
|
||||
import torch.utils.cpp_extension
|
||||
from torch.utils.file_baton import FileBaton
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Global options.
|
||||
|
||||
verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full'
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Internal helper funcs.
|
||||
|
||||
def _find_compiler_bindir():
|
||||
patterns = [
|
||||
'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
||||
'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
||||
'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64',
|
||||
'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin',
|
||||
]
|
||||
for pattern in patterns:
|
||||
matches = sorted(glob.glob(pattern))
|
||||
if len(matches):
|
||||
return matches[-1]
|
||||
return None
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _get_mangled_gpu_name():
|
||||
name = torch.cuda.get_device_name().lower()
|
||||
out = []
|
||||
for c in name:
|
||||
if re.match('[a-z0-9_-]+', c):
|
||||
out.append(c)
|
||||
else:
|
||||
out.append('-')
|
||||
return ''.join(out)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Main entry point for compiling and loading C++/CUDA plugins.
|
||||
|
||||
_cached_plugins = dict()
|
||||
|
||||
def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs):
|
||||
assert verbosity in ['none', 'brief', 'full']
|
||||
if headers is None:
|
||||
headers = []
|
||||
if source_dir is not None:
|
||||
sources = [os.path.join(source_dir, fname) for fname in sources]
|
||||
headers = [os.path.join(source_dir, fname) for fname in headers]
|
||||
|
||||
# Already cached?
|
||||
if module_name in _cached_plugins:
|
||||
return _cached_plugins[module_name]
|
||||
|
||||
# Print status.
|
||||
if verbosity == 'full':
|
||||
print(f'Setting up PyTorch plugin "{module_name}"...')
|
||||
elif verbosity == 'brief':
|
||||
print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True)
|
||||
verbose_build = (verbosity == 'full')
|
||||
|
||||
# Compile and load.
|
||||
try: # pylint: disable=too-many-nested-blocks
|
||||
# Make sure we can find the necessary compiler binaries.
|
||||
if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0:
|
||||
compiler_bindir = _find_compiler_bindir()
|
||||
if compiler_bindir is None:
|
||||
raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".')
|
||||
os.environ['PATH'] += ';' + compiler_bindir
|
||||
|
||||
# Some containers set TORCH_CUDA_ARCH_LIST to a list that can either
|
||||
# break the build or unnecessarily restrict what's available to nvcc.
|
||||
# Unset it to let nvcc decide based on what's available on the
|
||||
# machine.
|
||||
os.environ['TORCH_CUDA_ARCH_LIST'] = ''
|
||||
|
||||
# Incremental build md5sum trickery. Copies all the input source files
|
||||
# into a cached build directory under a combined md5 digest of the input
|
||||
# source files. Copying is done only if the combined digest has changed.
|
||||
# This keeps input file timestamps and filenames the same as in previous
|
||||
# extension builds, allowing for fast incremental rebuilds.
|
||||
#
|
||||
# This optimization is done only in case all the source files reside in
|
||||
# a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR
|
||||
# environment variable is set (we take this as a signal that the user
|
||||
# actually cares about this.)
|
||||
#
|
||||
# EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work
|
||||
# around the *.cu dependency bug in ninja config.
|
||||
#
|
||||
all_source_files = sorted(sources + headers)
|
||||
all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files)
|
||||
if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ):
|
||||
|
||||
# Compute combined hash digest for all source files.
|
||||
hash_md5 = hashlib.md5()
|
||||
for src in all_source_files:
|
||||
with open(src, 'rb') as f:
|
||||
hash_md5.update(f.read())
|
||||
|
||||
# Select cached build directory name.
|
||||
source_digest = hash_md5.hexdigest()
|
||||
build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access
|
||||
cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}')
|
||||
|
||||
if not os.path.isdir(cached_build_dir):
|
||||
tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}'
|
||||
os.makedirs(tmpdir)
|
||||
for src in all_source_files:
|
||||
shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src)))
|
||||
try:
|
||||
os.replace(tmpdir, cached_build_dir) # atomic
|
||||
except OSError:
|
||||
# source directory already exists, delete tmpdir and its contents.
|
||||
shutil.rmtree(tmpdir)
|
||||
if not os.path.isdir(cached_build_dir): raise
|
||||
|
||||
# Compile.
|
||||
cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources]
|
||||
torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir,
|
||||
verbose=verbose_build, sources=cached_sources, **build_kwargs)
|
||||
else:
|
||||
torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs)
|
||||
|
||||
# Load.
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
except:
|
||||
if verbosity == 'brief':
|
||||
print('Failed!')
|
||||
raise
|
||||
|
||||
# Print status and add to cache dict.
|
||||
if verbosity == 'full':
|
||||
print(f'Done setting up PyTorch plugin "{module_name}".')
|
||||
elif verbosity == 'brief':
|
||||
print('Done.')
|
||||
_cached_plugins[module_name] = module
|
||||
return module
|
||||
|
||||
#----------------------------------------------------------------------------
|
11
svrm/ldm/modules/rendering_neus/third_party/dnnlib/__init__.py
vendored
Normal file
@ -0,0 +1,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
from .util import EasyDict, make_cache_dir_path
|
493
svrm/ldm/modules/rendering_neus/third_party/dnnlib/util.py
vendored
Normal file
@ -0,0 +1,493 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
"""Miscellaneous utility classes and functions."""
|
||||
|
||||
import ctypes
|
||||
import fnmatch
|
||||
import importlib
|
||||
import inspect
|
||||
import numpy as np
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import types
|
||||
import io
|
||||
import pickle
|
||||
import re
|
||||
import requests
|
||||
import html
|
||||
import hashlib
|
||||
import glob
|
||||
import tempfile
|
||||
import urllib
|
||||
import urllib.request
|
||||
import uuid
|
||||
|
||||
from distutils.util import strtobool
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
|
||||
# Util classes
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
class EasyDict(dict):
|
||||
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
try:
|
||||
return self[name]
|
||||
except KeyError:
|
||||
raise AttributeError(name)
|
||||
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
self[name] = value
|
||||
|
||||
def __delattr__(self, name: str) -> None:
|
||||
del self[name]
|
||||
|
||||
|
||||
class Logger(object):
|
||||
"""Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
|
||||
|
||||
def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
|
||||
self.file = None
|
||||
|
||||
if file_name is not None:
|
||||
self.file = open(file_name, file_mode)
|
||||
|
||||
self.should_flush = should_flush
|
||||
self.stdout = sys.stdout
|
||||
self.stderr = sys.stderr
|
||||
|
||||
sys.stdout = self
|
||||
sys.stderr = self
|
||||
|
||||
def __enter__(self) -> "Logger":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def write(self, text: Union[str, bytes]) -> None:
|
||||
"""Write text to stdout (and a file) and optionally flush."""
|
||||
if isinstance(text, bytes):
|
||||
text = text.decode()
|
||||
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
|
||||
return
|
||||
|
||||
if self.file is not None:
|
||||
self.file.write(text)
|
||||
|
||||
self.stdout.write(text)
|
||||
|
||||
if self.should_flush:
|
||||
self.flush()
|
||||
|
||||
def flush(self) -> None:
|
||||
"""Flush written text to both stdout and a file, if open."""
|
||||
if self.file is not None:
|
||||
self.file.flush()
|
||||
|
||||
self.stdout.flush()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Flush, close possible files, and remove stdout/stderr mirroring."""
|
||||
self.flush()
|
||||
|
||||
# if using multiple loggers, prevent closing in wrong order
|
||||
if sys.stdout is self:
|
||||
sys.stdout = self.stdout
|
||||
if sys.stderr is self:
|
||||
sys.stderr = self.stderr
|
||||
|
||||
if self.file is not None:
|
||||
self.file.close()
|
||||
self.file = None
|
||||
|
||||
|
||||
# Cache directories
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
_dnnlib_cache_dir = None
|
||||
|
||||
def set_cache_dir(path: str) -> None:
|
||||
global _dnnlib_cache_dir
|
||||
_dnnlib_cache_dir = path
|
||||
|
||||
def make_cache_dir_path(*paths: str) -> str:
|
||||
if _dnnlib_cache_dir is not None:
|
||||
return os.path.join(_dnnlib_cache_dir, *paths)
|
||||
if 'DNNLIB_CACHE_DIR' in os.environ:
|
||||
return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
|
||||
if 'HOME' in os.environ:
|
||||
return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
|
||||
if 'USERPROFILE' in os.environ:
|
||||
return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
|
||||
return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
||||
|
||||
# Small util functions
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
|
||||
def format_time(seconds: Union[int, float]) -> str:
|
||||
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
||||
s = int(np.rint(seconds))
|
||||
|
||||
if s < 60:
|
||||
return "{0}s".format(s)
|
||||
elif s < 60 * 60:
|
||||
return "{0}m {1:02}s".format(s // 60, s % 60)
|
||||
elif s < 24 * 60 * 60:
|
||||
return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
|
||||
else:
|
||||
return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
|
||||
|
||||
|
||||
def format_time_brief(seconds: Union[int, float]) -> str:
|
||||
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
||||
s = int(np.rint(seconds))
|
||||
|
||||
if s < 60:
|
||||
return "{0}s".format(s)
|
||||
elif s < 60 * 60:
|
||||
return "{0}m {1:02}s".format(s // 60, s % 60)
|
||||
elif s < 24 * 60 * 60:
|
||||
return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
|
||||
else:
|
||||
return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
|
||||
|
||||
|
||||
def ask_yes_no(question: str) -> bool:
|
||||
"""Ask the user the question until the user inputs a valid answer."""
|
||||
while True:
|
||||
try:
|
||||
print("{0} [y/n]".format(question))
|
||||
return strtobool(input().lower())
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def tuple_product(t: Tuple) -> Any:
|
||||
"""Calculate the product of the tuple elements."""
|
||||
result = 1
|
||||
|
||||
for v in t:
|
||||
result *= v
|
||||
|
||||
return result
|
||||
|
||||
|
||||
_str_to_ctype = {
|
||||
"uint8": ctypes.c_ubyte,
|
||||
"uint16": ctypes.c_uint16,
|
||||
"uint32": ctypes.c_uint32,
|
||||
"uint64": ctypes.c_uint64,
|
||||
"int8": ctypes.c_byte,
|
||||
"int16": ctypes.c_int16,
|
||||
"int32": ctypes.c_int32,
|
||||
"int64": ctypes.c_int64,
|
||||
"float32": ctypes.c_float,
|
||||
"float64": ctypes.c_double
|
||||
}
|
||||
|
||||
|
||||
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
|
||||
"""Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
|
||||
type_str = None
|
||||
|
||||
if isinstance(type_obj, str):
|
||||
type_str = type_obj
|
||||
elif hasattr(type_obj, "__name__"):
|
||||
type_str = type_obj.__name__
|
||||
elif hasattr(type_obj, "name"):
|
||||
type_str = type_obj.name
|
||||
else:
|
||||
raise RuntimeError("Cannot infer type name from input")
|
||||
|
||||
assert type_str in _str_to_ctype.keys()
|
||||
|
||||
my_dtype = np.dtype(type_str)
|
||||
my_ctype = _str_to_ctype[type_str]
|
||||
|
||||
assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
|
||||
|
||||
return my_dtype, my_ctype
|
||||
|
||||
|
||||
def is_pickleable(obj: Any) -> bool:
|
||||
try:
|
||||
with io.BytesIO() as stream:
|
||||
pickle.dump(obj, stream)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
# Functionality to import modules/objects by name, and call functions by name
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
|
||||
"""Searches for the underlying module behind the name to some python object.
|
||||
Returns the module and the object name (original name with module part removed)."""
|
||||
|
||||
# allow convenience shorthands, substitute them by full names
|
||||
obj_name = re.sub("^np.", "numpy.", obj_name)
|
||||
obj_name = re.sub("^tf.", "tensorflow.", obj_name)
|
||||
|
||||
# list alternatives for (module_name, local_obj_name)
|
||||
parts = obj_name.split(".")
|
||||
name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
|
||||
|
||||
# try each alternative in turn
|
||||
for module_name, local_obj_name in name_pairs:
|
||||
try:
|
||||
module = importlib.import_module(module_name) # may raise ImportError
|
||||
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
||||
return module, local_obj_name
|
||||
except:
|
||||
pass
|
||||
|
||||
# maybe some of the modules themselves contain errors?
|
||||
for module_name, _local_obj_name in name_pairs:
|
||||
try:
|
||||
importlib.import_module(module_name) # may raise ImportError
|
||||
except ImportError:
|
||||
if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
|
||||
raise
|
||||
|
||||
# maybe the requested attribute is missing?
|
||||
for module_name, local_obj_name in name_pairs:
|
||||
try:
|
||||
module = importlib.import_module(module_name) # may raise ImportError
|
||||
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# we are out of luck, but we have no idea why
|
||||
raise ImportError(obj_name)
|
||||
|
||||
|
||||
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
|
||||
"""Traverses the object name and returns the last (rightmost) python object."""
|
||||
if obj_name == '':
|
||||
return module
|
||||
obj = module
|
||||
for part in obj_name.split("."):
|
||||
obj = getattr(obj, part)
|
||||
return obj
|
||||
|
||||
|
||||
def get_obj_by_name(name: str) -> Any:
|
||||
"""Finds the python object with the given name."""
|
||||
module, obj_name = get_module_from_obj_name(name)
|
||||
return get_obj_from_module(module, obj_name)
|
||||
|
||||
|
||||
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
|
||||
"""Finds the python object with the given name and calls it as a function."""
|
||||
assert func_name is not None
|
||||
func_obj = get_obj_by_name(func_name)
|
||||
assert callable(func_obj)
|
||||
return func_obj(*args, **kwargs)
|
||||
|
||||
|
||||
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
|
||||
"""Finds the python class with the given name and constructs it with the given arguments."""
|
||||
return call_func_by_name(*args, func_name=class_name, **kwargs)
|
||||
|
||||
|
||||
def get_module_dir_by_obj_name(obj_name: str) -> str:
|
||||
"""Get the directory path of the module containing the given object name."""
|
||||
module, _ = get_module_from_obj_name(obj_name)
|
||||
return os.path.dirname(inspect.getfile(module))
|
||||
|
||||
|
||||
def is_top_level_function(obj: Any) -> bool:
|
||||
"""Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
|
||||
return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
|
||||
|
||||
|
||||
def get_top_level_function_name(obj: Any) -> str:
|
||||
"""Return the fully-qualified name of a top-level function."""
|
||||
assert is_top_level_function(obj)
|
||||
module = obj.__module__
|
||||
if module == '__main__':
|
||||
module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
|
||||
return module + "." + obj.__name__
|
||||
|
||||
|
||||
# File system helpers
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
|
||||
"""List all files recursively in a given directory while ignoring given file and directory names.
|
||||
Returns list of tuples containing both absolute and relative paths."""
|
||||
assert os.path.isdir(dir_path)
|
||||
base_name = os.path.basename(os.path.normpath(dir_path))
|
||||
|
||||
if ignores is None:
|
||||
ignores = []
|
||||
|
||||
result = []
|
||||
|
||||
for root, dirs, files in os.walk(dir_path, topdown=True):
|
||||
for ignore_ in ignores:
|
||||
dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
|
||||
|
||||
# dirs need to be edited in-place
|
||||
for d in dirs_to_remove:
|
||||
dirs.remove(d)
|
||||
|
||||
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
|
||||
|
||||
absolute_paths = [os.path.join(root, f) for f in files]
|
||||
relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
|
||||
|
||||
if add_base_to_relative:
|
||||
relative_paths = [os.path.join(base_name, p) for p in relative_paths]
|
||||
|
||||
assert len(absolute_paths) == len(relative_paths)
|
||||
result += zip(absolute_paths, relative_paths)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
|
||||
"""Takes in a list of tuples of (src, dst) paths and copies files.
|
||||
Will create all necessary directories."""
|
||||
for file in files:
|
||||
target_dir_name = os.path.dirname(file[1])
|
||||
|
||||
# will create all intermediate-level directories
|
||||
if not os.path.exists(target_dir_name):
|
||||
os.makedirs(target_dir_name)
|
||||
|
||||
shutil.copyfile(file[0], file[1])
|
||||
|
||||
|
||||
# URL helpers
|
||||
# ------------------------------------------------------------------------------------------
|
||||
|
||||
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
|
||||
"""Determine whether the given object is a valid URL string."""
|
||||
if not isinstance(obj, str) or not "://" in obj:
|
||||
return False
|
||||
if allow_file_urls and obj.startswith('file://'):
|
||||
return True
|
||||
try:
|
||||
res = requests.compat.urlparse(obj)
|
||||
if not res.scheme or not res.netloc or not "." in res.netloc:
|
||||
return False
|
||||
res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
|
||||
if not res.scheme or not res.netloc or not "." in res.netloc:
|
||||
return False
|
||||
except:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
|
||||
"""Download the given URL and return a binary-mode file object to access the data."""
|
||||
assert num_attempts >= 1
|
||||
assert not (return_filename and (not cache))
|
||||
|
||||
# Doesn't look like an URL scheme so interpret it as a local filename.
|
||||
if not re.match('^[a-z]+://', url):
|
||||
return url if return_filename else open(url, "rb")
|
||||
|
||||
# Handle file URLs. This code handles unusual file:// patterns that
|
||||
# arise on Windows:
|
||||
#
|
||||
# file:///c:/foo.txt
|
||||
#
|
||||
# which would translate to a local '/c:/foo.txt' filename that's
|
||||
# invalid. Drop the forward slash for such pathnames.
|
||||
#
|
||||
# If you touch this code path, you should test it on both Linux and
|
||||
# Windows.
|
||||
#
|
||||
# Some internet resources suggest using urllib.request.url2pathname() but
|
||||
# but that converts forward slashes to backslashes and this causes
|
||||
# its own set of problems.
|
||||
if url.startswith('file://'):
|
||||
filename = urllib.parse.urlparse(url).path
|
||||
if re.match(r'^/[a-zA-Z]:', filename):
|
||||
filename = filename[1:]
|
||||
return filename if return_filename else open(filename, "rb")
|
||||
|
||||
assert is_url(url)
|
||||
|
||||
# Lookup from cache.
|
||||
if cache_dir is None:
|
||||
cache_dir = make_cache_dir_path('downloads')
|
||||
|
||||
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
||||
if cache:
|
||||
cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
|
||||
if len(cache_files) == 1:
|
||||
filename = cache_files[0]
|
||||
return filename if return_filename else open(filename, "rb")
|
||||
|
||||
# Download.
|
||||
url_name = None
|
||||
url_data = None
|
||||
with requests.Session() as session:
|
||||
if verbose:
|
||||
print("Downloading %s ..." % url, end="", flush=True)
|
||||
for attempts_left in reversed(range(num_attempts)):
|
||||
try:
|
||||
with session.get(url) as res:
|
||||
res.raise_for_status()
|
||||
if len(res.content) == 0:
|
||||
raise IOError("No data received")
|
||||
|
||||
if len(res.content) < 8192:
|
||||
content_str = res.content.decode("utf-8")
|
||||
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
||||
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
|
||||
if len(links) == 1:
|
||||
url = requests.compat.urljoin(url, links[0])
|
||||
raise IOError("Google Drive virus checker nag")
|
||||
if "Google Drive - Quota exceeded" in content_str:
|
||||
raise IOError("Google Drive download quota exceeded -- please try again later")
|
||||
|
||||
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
|
||||
url_name = match[1] if match else url
|
||||
url_data = res.content
|
||||
if verbose:
|
||||
print(" done")
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
raise
|
||||
except:
|
||||
if not attempts_left:
|
||||
if verbose:
|
||||
print(" failed")
|
||||
raise
|
||||
if verbose:
|
||||
print(".", end="", flush=True)
|
||||
|
||||
# Save to cache.
|
||||
if cache:
|
||||
safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
|
||||
cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
|
||||
temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
with open(temp_file, "wb") as f:
|
||||
f.write(url_data)
|
||||
os.replace(temp_file, cache_file) # atomic
|
||||
if return_filename:
|
||||
return cache_file
|
||||
|
||||
# Return data as file object.
|
||||
assert not return_filename
|
||||
return io.BytesIO(url_data)
|
268
svrm/ldm/modules/rendering_neus/third_party/misc.py
vendored
Normal file
@ -0,0 +1,268 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
import re
|
||||
import contextlib
|
||||
import numpy as np
|
||||
import torch
|
||||
import warnings
|
||||
from ldm.modules.neus.third_party import dnnlib
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Cached construction of constant tensors. Avoids CPU=>GPU copy when the
|
||||
# same constant is used multiple times.
|
||||
|
||||
_constant_cache = dict()
|
||||
|
||||
def constant(value, shape=None, dtype=None, device=None, memory_format=None):
|
||||
value = np.asarray(value)
|
||||
if shape is not None:
|
||||
shape = tuple(shape)
|
||||
if dtype is None:
|
||||
dtype = torch.get_default_dtype()
|
||||
if device is None:
|
||||
device = torch.device('cpu')
|
||||
if memory_format is None:
|
||||
memory_format = torch.contiguous_format
|
||||
|
||||
key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
|
||||
tensor = _constant_cache.get(key, None)
|
||||
if tensor is None:
|
||||
tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
|
||||
if shape is not None:
|
||||
tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
|
||||
tensor = tensor.contiguous(memory_format=memory_format)
|
||||
_constant_cache[key] = tensor
|
||||
return tensor
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Replace NaN/Inf with specified numerical values.
|
||||
|
||||
try:
|
||||
nan_to_num = torch.nan_to_num # 1.8.0a0
|
||||
except AttributeError:
|
||||
def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin
|
||||
assert isinstance(input, torch.Tensor)
|
||||
if posinf is None:
|
||||
posinf = torch.finfo(input.dtype).max
|
||||
if neginf is None:
|
||||
neginf = torch.finfo(input.dtype).min
|
||||
assert nan == 0
|
||||
return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Symbolic assert.
|
||||
|
||||
try:
|
||||
symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access
|
||||
except AttributeError:
|
||||
symbolic_assert = torch.Assert # 1.7.0
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Context manager to temporarily suppress known warnings in torch.jit.trace().
|
||||
# Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
|
||||
|
||||
@contextlib.contextmanager
|
||||
def suppress_tracer_warnings():
|
||||
flt = ('ignore', None, torch.jit.TracerWarning, None, 0)
|
||||
warnings.filters.insert(0, flt)
|
||||
yield
|
||||
warnings.filters.remove(flt)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Assert that the shape of a tensor matches the given list of integers.
|
||||
# None indicates that the size of a dimension is allowed to vary.
|
||||
# Performs symbolic assertion when used in torch.jit.trace().
|
||||
|
||||
def assert_shape(tensor, ref_shape):
|
||||
if tensor.ndim != len(ref_shape):
|
||||
raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}')
|
||||
for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)):
|
||||
if ref_size is None:
|
||||
pass
|
||||
elif isinstance(ref_size, torch.Tensor):
|
||||
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
||||
symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}')
|
||||
elif isinstance(size, torch.Tensor):
|
||||
with suppress_tracer_warnings(): # as_tensor results are registered as constants
|
||||
symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}')
|
||||
elif size != ref_size:
|
||||
raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}')
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Function decorator that calls torch.autograd.profiler.record_function().
|
||||
|
||||
def profiled_function(fn):
|
||||
def decorator(*args, **kwargs):
|
||||
with torch.autograd.profiler.record_function(fn.__name__):
|
||||
return fn(*args, **kwargs)
|
||||
decorator.__name__ = fn.__name__
|
||||
return decorator
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Sampler for torch.utils.data.DataLoader that loops over the dataset
|
||||
# indefinitely, shuffling items as it goes.
|
||||
|
||||
class InfiniteSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5):
|
||||
assert len(dataset) > 0
|
||||
assert num_replicas > 0
|
||||
assert 0 <= rank < num_replicas
|
||||
assert 0 <= window_size <= 1
|
||||
super().__init__(dataset)
|
||||
self.dataset = dataset
|
||||
self.rank = rank
|
||||
self.num_replicas = num_replicas
|
||||
self.shuffle = shuffle
|
||||
self.seed = seed
|
||||
self.window_size = window_size
|
||||
|
||||
def __iter__(self):
|
||||
order = np.arange(len(self.dataset))
|
||||
rnd = None
|
||||
window = 0
|
||||
if self.shuffle:
|
||||
rnd = np.random.RandomState(self.seed)
|
||||
rnd.shuffle(order)
|
||||
window = int(np.rint(order.size * self.window_size))
|
||||
|
||||
idx = 0
|
||||
while True:
|
||||
i = idx % order.size
|
||||
if idx % self.num_replicas == self.rank:
|
||||
yield order[i]
|
||||
if window >= 2:
|
||||
j = (i - rnd.randint(window)) % order.size
|
||||
order[i], order[j] = order[j], order[i]
|
||||
idx += 1
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Utilities for operating with torch.nn.Module parameters and buffers.
|
||||
|
||||
def params_and_buffers(module):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
return list(module.parameters()) + list(module.buffers())
|
||||
|
||||
def named_params_and_buffers(module):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
return list(module.named_parameters()) + list(module.named_buffers())
|
||||
|
||||
def copy_params_and_buffers(src_module, dst_module, require_all=False):
|
||||
assert isinstance(src_module, torch.nn.Module)
|
||||
assert isinstance(dst_module, torch.nn.Module)
|
||||
src_tensors = dict(named_params_and_buffers(src_module))
|
||||
for name, tensor in named_params_and_buffers(dst_module):
|
||||
assert (name in src_tensors) or (not require_all)
|
||||
if name in src_tensors:
|
||||
tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Context manager for easily enabling/disabling DistributedDataParallel
|
||||
# synchronization.
|
||||
|
||||
@contextlib.contextmanager
|
||||
def ddp_sync(module, sync):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel):
|
||||
yield
|
||||
else:
|
||||
with module.no_sync():
|
||||
yield
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Check DistributedDataParallel consistency across processes.
|
||||
|
||||
def check_ddp_consistency(module, ignore_regex=None):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
for name, tensor in named_params_and_buffers(module):
|
||||
fullname = type(module).__name__ + '.' + name
|
||||
if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
|
||||
continue
|
||||
tensor = tensor.detach()
|
||||
if tensor.is_floating_point():
|
||||
tensor = nan_to_num(tensor)
|
||||
other = tensor.clone()
|
||||
torch.distributed.broadcast(tensor=other, src=0)
|
||||
assert (tensor == other).all(), fullname
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
# Print summary table of module hierarchy.
|
||||
|
||||
def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
|
||||
assert isinstance(module, torch.nn.Module)
|
||||
assert not isinstance(module, torch.jit.ScriptModule)
|
||||
assert isinstance(inputs, (tuple, list))
|
||||
|
||||
# Register hooks.
|
||||
entries = []
|
||||
nesting = [0]
|
||||
def pre_hook(_mod, _inputs):
|
||||
nesting[0] += 1
|
||||
def post_hook(mod, _inputs, outputs):
|
||||
nesting[0] -= 1
|
||||
if nesting[0] <= max_nesting:
|
||||
outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
|
||||
outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
|
||||
entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
|
||||
hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
|
||||
hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
|
||||
|
||||
# Run module.
|
||||
outputs = module(*inputs)
|
||||
for hook in hooks:
|
||||
hook.remove()
|
||||
|
||||
# Identify unique outputs, parameters, and buffers.
|
||||
tensors_seen = set()
|
||||
for e in entries:
|
||||
e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
|
||||
e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
|
||||
e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
|
||||
tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
|
||||
|
||||
# Filter out redundant entries.
|
||||
if skip_redundant:
|
||||
entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
|
||||
|
||||
# Construct table.
|
||||
rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
|
||||
rows += [['---'] * len(rows[0])]
|
||||
param_total = 0
|
||||
buffer_total = 0
|
||||
submodule_names = {mod: name for name, mod in module.named_modules()}
|
||||
for e in entries:
|
||||
name = '<top-level>' if e.mod is module else submodule_names[e.mod]
|
||||
param_size = sum(t.numel() for t in e.unique_params)
|
||||
buffer_size = sum(t.numel() for t in e.unique_buffers)
|
||||
output_shapes = [str(list(t.shape)) for t in e.outputs]
|
||||
output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
|
||||
rows += [[
|
||||
name + (':0' if len(e.outputs) >= 2 else ''),
|
||||
str(param_size) if param_size else '-',
|
||||
str(buffer_size) if buffer_size else '-',
|
||||
(output_shapes + ['-'])[0],
|
||||
(output_dtypes + ['-'])[0],
|
||||
]]
|
||||
for idx in range(1, len(e.outputs)):
|
||||
rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
|
||||
param_total += param_size
|
||||
buffer_total += buffer_size
|
||||
rows += [['---'] * len(rows[0])]
|
||||
rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
|
||||
|
||||
# Print table.
|
||||
widths = [max(len(cell) for cell in column) for column in zip(*rows)]
|
||||
print()
|
||||
for row in rows:
|
||||
print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
|
||||
print()
|
||||
return outputs
|
||||
|
||||
#----------------------------------------------------------------------------
|
11
svrm/ldm/modules/rendering_neus/third_party/ops/__init__.py
vendored
Normal file
@ -0,0 +1,11 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
# empty
|
103
svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.cpp
vendored
Normal file
@ -0,0 +1,103 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include "bias_act.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
static bool has_same_layout(torch::Tensor x, torch::Tensor y)
|
||||
{
|
||||
if (x.dim() != y.dim())
|
||||
return false;
|
||||
for (int64_t i = 0; i < x.dim(); i++)
|
||||
{
|
||||
if (x.size(i) != y.size(i))
|
||||
return false;
|
||||
if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
|
||||
{
|
||||
// Validate arguments.
|
||||
TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
|
||||
TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
|
||||
TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
|
||||
TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
|
||||
TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
|
||||
TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
|
||||
TORCH_CHECK(b.dim() == 1, "b must have rank 1");
|
||||
TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
|
||||
TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
|
||||
TORCH_CHECK(grad >= 0, "grad must be non-negative");
|
||||
|
||||
// Validate layout.
|
||||
TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
|
||||
TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
|
||||
TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
|
||||
TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
|
||||
TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
|
||||
|
||||
// Create output tensor.
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
torch::Tensor y = torch::empty_like(x);
|
||||
TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
|
||||
|
||||
// Initialize CUDA kernel parameters.
|
||||
bias_act_kernel_params p;
|
||||
p.x = x.data_ptr();
|
||||
p.b = (b.numel()) ? b.data_ptr() : NULL;
|
||||
p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
|
||||
p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
|
||||
p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
|
||||
p.y = y.data_ptr();
|
||||
p.grad = grad;
|
||||
p.act = act;
|
||||
p.alpha = alpha;
|
||||
p.gain = gain;
|
||||
p.clamp = clamp;
|
||||
p.sizeX = (int)x.numel();
|
||||
p.sizeB = (int)b.numel();
|
||||
p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
|
||||
|
||||
// Choose CUDA kernel.
|
||||
void* kernel;
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
|
||||
{
|
||||
kernel = choose_bias_act_kernel<scalar_t>(p);
|
||||
});
|
||||
TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
|
||||
|
||||
// Launch CUDA kernel.
|
||||
p.loopX = 4;
|
||||
int blockSize = 4 * 32;
|
||||
int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
|
||||
void* args[] = {&p};
|
||||
AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
|
||||
return y;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
{
|
||||
m.def("bias_act", &bias_act);
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
177
svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.cu
vendored
Normal file
@ -0,0 +1,177 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#include <c10/util/Half.h>
|
||||
#include "bias_act.h"
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Helpers.
|
||||
|
||||
template <class T> struct InternalType;
|
||||
template <> struct InternalType<double> { typedef double scalar_t; };
|
||||
template <> struct InternalType<float> { typedef float scalar_t; };
|
||||
template <> struct InternalType<c10::Half> { typedef float scalar_t; };
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel.
|
||||
|
||||
template <class T, int A>
|
||||
__global__ void bias_act_kernel(bias_act_kernel_params p)
|
||||
{
|
||||
typedef typename InternalType<T>::scalar_t scalar_t;
|
||||
int G = p.grad;
|
||||
scalar_t alpha = (scalar_t)p.alpha;
|
||||
scalar_t gain = (scalar_t)p.gain;
|
||||
scalar_t clamp = (scalar_t)p.clamp;
|
||||
scalar_t one = (scalar_t)1;
|
||||
scalar_t two = (scalar_t)2;
|
||||
scalar_t expRange = (scalar_t)80;
|
||||
scalar_t halfExpRange = (scalar_t)40;
|
||||
scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
|
||||
scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
|
||||
|
||||
// Loop over elements.
|
||||
int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
|
||||
for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
|
||||
{
|
||||
// Load.
|
||||
scalar_t x = (scalar_t)((const T*)p.x)[xi];
|
||||
scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
|
||||
scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
|
||||
scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
|
||||
scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
|
||||
scalar_t yy = (gain != 0) ? yref / gain : 0;
|
||||
scalar_t y = 0;
|
||||
|
||||
// Apply bias.
|
||||
((G == 0) ? x : xref) += b;
|
||||
|
||||
// linear
|
||||
if (A == 1)
|
||||
{
|
||||
if (G == 0) y = x;
|
||||
if (G == 1) y = x;
|
||||
}
|
||||
|
||||
// relu
|
||||
if (A == 2)
|
||||
{
|
||||
if (G == 0) y = (x > 0) ? x : 0;
|
||||
if (G == 1) y = (yy > 0) ? x : 0;
|
||||
}
|
||||
|
||||
// lrelu
|
||||
if (A == 3)
|
||||
{
|
||||
if (G == 0) y = (x > 0) ? x : x * alpha;
|
||||
if (G == 1) y = (yy > 0) ? x : x * alpha;
|
||||
}
|
||||
|
||||
// tanh
|
||||
if (A == 4)
|
||||
{
|
||||
if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
|
||||
if (G == 1) y = x * (one - yy * yy);
|
||||
if (G == 2) y = x * (one - yy * yy) * (-two * yy);
|
||||
}
|
||||
|
||||
// sigmoid
|
||||
if (A == 5)
|
||||
{
|
||||
if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
|
||||
if (G == 1) y = x * yy * (one - yy);
|
||||
if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
|
||||
}
|
||||
|
||||
// elu
|
||||
if (A == 6)
|
||||
{
|
||||
if (G == 0) y = (x >= 0) ? x : exp(x) - one;
|
||||
if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
|
||||
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
|
||||
}
|
||||
|
||||
// selu
|
||||
if (A == 7)
|
||||
{
|
||||
if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
|
||||
if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
|
||||
if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
|
||||
}
|
||||
|
||||
// softplus
|
||||
if (A == 8)
|
||||
{
|
||||
if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
|
||||
if (G == 1) y = x * (one - exp(-yy));
|
||||
if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
|
||||
}
|
||||
|
||||
// swish
|
||||
if (A == 9)
|
||||
{
|
||||
if (G == 0)
|
||||
y = (x < -expRange) ? 0 : x / (exp(-x) + one);
|
||||
else
|
||||
{
|
||||
scalar_t c = exp(xref);
|
||||
scalar_t d = c + one;
|
||||
if (G == 1)
|
||||
y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
|
||||
else
|
||||
y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
|
||||
yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
|
||||
}
|
||||
}
|
||||
|
||||
// Apply gain.
|
||||
y *= gain * dy;
|
||||
|
||||
// Clamp.
|
||||
if (clamp >= 0)
|
||||
{
|
||||
if (G == 0)
|
||||
y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
|
||||
else
|
||||
y = (yref > -clamp & yref < clamp) ? y : 0;
|
||||
}
|
||||
|
||||
// Store.
|
||||
((T*)p.y)[xi] = (T)y;
|
||||
}
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
|
||||
{
|
||||
if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
|
||||
if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
|
||||
if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
|
||||
if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
|
||||
if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
|
||||
if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
|
||||
if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
|
||||
if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
|
||||
if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
|
||||
return NULL;
|
||||
}
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// Template specializations.
|
||||
|
||||
template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
|
||||
template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
|
||||
template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
42
svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.h
vendored
Normal file
@ -0,0 +1,42 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel parameters.
|
||||
|
||||
struct bias_act_kernel_params
|
||||
{
|
||||
const void* x; // [sizeX]
|
||||
const void* b; // [sizeB] or NULL
|
||||
const void* xref; // [sizeX] or NULL
|
||||
const void* yref; // [sizeX] or NULL
|
||||
const void* dy; // [sizeX] or NULL
|
||||
void* y; // [sizeX]
|
||||
|
||||
int grad;
|
||||
int act;
|
||||
float alpha;
|
||||
float gain;
|
||||
float clamp;
|
||||
|
||||
int sizeX;
|
||||
int sizeB;
|
||||
int stepB;
|
||||
int loopX;
|
||||
};
|
||||
|
||||
//------------------------------------------------------------------------
|
||||
// CUDA kernel selection.
|
||||
|
||||
template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
|
||||
|
||||
//------------------------------------------------------------------------
|
211
svrm/ldm/modules/rendering_neus/third_party/ops/bias_act.py
vendored
Normal file
@ -0,0 +1,211 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
"""Custom PyTorch ops for efficient bias and activation."""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
from ldm.modules.neus.third_party import dnnlib
|
||||
|
||||
from .. import custom_ops
|
||||
from .. import misc
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
activation_funcs = {
|
||||
'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
|
||||
'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
|
||||
'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
|
||||
'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
|
||||
'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
|
||||
'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
|
||||
'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
|
||||
'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
|
||||
'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
|
||||
}
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_plugin = None
|
||||
_null_tensor = torch.empty([0])
|
||||
|
||||
def _init():
|
||||
global _plugin
|
||||
if _plugin is None:
|
||||
_plugin = custom_ops.get_plugin(
|
||||
module_name='bias_act_plugin',
|
||||
sources=['bias_act.cpp', 'bias_act.cu'],
|
||||
headers=['bias_act.h'],
|
||||
source_dir=os.path.dirname(__file__),
|
||||
extra_cuda_cflags=['--use_fast_math'],
|
||||
)
|
||||
return True
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
|
||||
r"""Fused bias and activation function.
|
||||
|
||||
Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
|
||||
and scales the result by `gain`. Each of the steps is optional. In most cases,
|
||||
the fused op is considerably more efficient than performing the same calculation
|
||||
using standard PyTorch ops. It supports first and second order gradients,
|
||||
but not third order gradients.
|
||||
|
||||
Args:
|
||||
x: Input activation tensor. Can be of any shape.
|
||||
b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
|
||||
as `x`. The shape must be known, and it must match the dimension of `x`
|
||||
corresponding to `dim`.
|
||||
dim: The dimension in `x` corresponding to the elements of `b`.
|
||||
The value of `dim` is ignored if `b` is not specified.
|
||||
act: Name of the activation function to evaluate, or `"linear"` to disable.
|
||||
Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
|
||||
See `activation_funcs` for a full list. `None` is not allowed.
|
||||
alpha: Shape parameter for the activation function, or `None` to use the default.
|
||||
gain: Scaling factor for the output tensor, or `None` to use default.
|
||||
See `activation_funcs` for the default scaling of each activation function.
|
||||
If unsure, consider specifying 1.
|
||||
clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
|
||||
the clamping (default).
|
||||
impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
|
||||
|
||||
Returns:
|
||||
Tensor of the same shape and datatype as `x`.
|
||||
"""
|
||||
assert isinstance(x, torch.Tensor)
|
||||
assert impl in ['ref', 'cuda']
|
||||
if impl == 'cuda' and x.device.type == 'cuda' and _init():
|
||||
return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
|
||||
return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
@misc.profiled_function
|
||||
def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
||||
"""Slow reference implementation of `bias_act()` using standard TensorFlow ops.
|
||||
"""
|
||||
assert isinstance(x, torch.Tensor)
|
||||
assert clamp is None or clamp >= 0
|
||||
spec = activation_funcs[act]
|
||||
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
||||
gain = float(gain if gain is not None else spec.def_gain)
|
||||
clamp = float(clamp if clamp is not None else -1)
|
||||
|
||||
# Add bias.
|
||||
if b is not None:
|
||||
assert isinstance(b, torch.Tensor) and b.ndim == 1
|
||||
assert 0 <= dim < x.ndim
|
||||
assert b.shape[0] == x.shape[dim]
|
||||
x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
|
||||
|
||||
# Evaluate activation function.
|
||||
alpha = float(alpha)
|
||||
x = spec.func(x, alpha=alpha)
|
||||
|
||||
# Scale by gain.
|
||||
gain = float(gain)
|
||||
if gain != 1:
|
||||
x = x * gain
|
||||
|
||||
# Clamp.
|
||||
if clamp >= 0:
|
||||
x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type
|
||||
return x
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
_bias_act_cuda_cache = dict()
|
||||
|
||||
def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None):
|
||||
"""Fast CUDA implementation of `bias_act()` using custom ops.
|
||||
"""
|
||||
# Parse arguments.
|
||||
assert clamp is None or clamp >= 0
|
||||
spec = activation_funcs[act]
|
||||
alpha = float(alpha if alpha is not None else spec.def_alpha)
|
||||
gain = float(gain if gain is not None else spec.def_gain)
|
||||
clamp = float(clamp if clamp is not None else -1)
|
||||
|
||||
# Lookup from cache.
|
||||
key = (dim, act, alpha, gain, clamp)
|
||||
if key in _bias_act_cuda_cache:
|
||||
return _bias_act_cuda_cache[key]
|
||||
|
||||
# Forward op.
|
||||
class BiasActCuda(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x, b): # pylint: disable=arguments-differ
|
||||
ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format
|
||||
x = x.contiguous(memory_format=ctx.memory_format)
|
||||
b = b.contiguous() if b is not None else _null_tensor
|
||||
y = x
|
||||
if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor:
|
||||
y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp)
|
||||
ctx.save_for_backward(
|
||||
x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
||||
b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor,
|
||||
y if 'y' in spec.ref else _null_tensor)
|
||||
return y
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dy): # pylint: disable=arguments-differ
|
||||
dy = dy.contiguous(memory_format=ctx.memory_format)
|
||||
x, b, y = ctx.saved_tensors
|
||||
dx = None
|
||||
db = None
|
||||
|
||||
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
|
||||
dx = dy
|
||||
if act != 'linear' or gain != 1 or clamp >= 0:
|
||||
dx = BiasActCudaGrad.apply(dy, x, b, y)
|
||||
|
||||
if ctx.needs_input_grad[1]:
|
||||
db = dx.sum([i for i in range(dx.ndim) if i != dim])
|
||||
|
||||
return dx, db
|
||||
|
||||
# Backward op.
|
||||
class BiasActCudaGrad(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ
|
||||
ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format
|
||||
dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp)
|
||||
ctx.save_for_backward(
|
||||
dy if spec.has_2nd_grad else _null_tensor,
|
||||
x, b, y)
|
||||
return dx
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, d_dx): # pylint: disable=arguments-differ
|
||||
d_dx = d_dx.contiguous(memory_format=ctx.memory_format)
|
||||
dy, x, b, y = ctx.saved_tensors
|
||||
d_dy = None
|
||||
d_x = None
|
||||
d_b = None
|
||||
d_y = None
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
d_dy = BiasActCudaGrad.apply(d_dx, x, b, y)
|
||||
|
||||
if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]):
|
||||
d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp)
|
||||
|
||||
if spec.has_2nd_grad and ctx.needs_input_grad[2]:
|
||||
d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim])
|
||||
|
||||
return d_dy, d_x, d_b, d_y
|
||||
|
||||
# Add to cache.
|
||||
_bias_act_cuda_cache[key] = BiasActCuda
|
||||
return BiasActCuda
|
||||
|
||||
#----------------------------------------------------------------------------
|
57
svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.cpp
vendored
Normal file
@ -0,0 +1,57 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
// CUDA forward declarations
|
||||
|
||||
namespace at {namespace native {
|
||||
std::vector<torch::Tensor> grid_sample2d_cuda_grad2(
|
||||
const torch::Tensor &grad2_grad_input,
|
||||
const torch::Tensor &grad2_grad_grid,
|
||||
const torch::Tensor &grad_output,
|
||||
const torch::Tensor &input,
|
||||
const torch::Tensor &grid,
|
||||
bool padding_mode,
|
||||
bool align_corners);
|
||||
std::vector<torch::Tensor> grid_sample3d_cuda_grad2(
|
||||
const torch::Tensor &grad2_grad_input,
|
||||
const torch::Tensor &grad2_grad_grid,
|
||||
const torch::Tensor &grad_output,
|
||||
const torch::Tensor &input,
|
||||
const torch::Tensor &grid,
|
||||
bool padding_mode,
|
||||
bool align_corners);
|
||||
}}
|
||||
|
||||
std::vector<torch::Tensor> grid_sample2d_grad2(
|
||||
const torch::Tensor &grad2_grad_input,
|
||||
const torch::Tensor &grad2_grad_grid,
|
||||
const torch::Tensor &grad_output,
|
||||
const torch::Tensor &input,
|
||||
const torch::Tensor &grid,
|
||||
bool padding_mode,
|
||||
bool align_corners) {
|
||||
|
||||
return at::native::grid_sample2d_cuda_grad2(grad2_grad_input, grad2_grad_grid,
|
||||
grad_output, input, grid, padding_mode, align_corners);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> grid_sample3d_grad2(
|
||||
const torch::Tensor &grad2_grad_input,
|
||||
const torch::Tensor &grad2_grad_grid,
|
||||
const torch::Tensor &grad_output,
|
||||
const torch::Tensor &input,
|
||||
const torch::Tensor &grid,
|
||||
bool padding_mode,
|
||||
bool align_corners) {
|
||||
|
||||
return at::native::grid_sample3d_cuda_grad2(grad2_grad_input, grad2_grad_grid,
|
||||
grad_output, input, grid, padding_mode, align_corners);
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("grad2_2d", &grid_sample2d_grad2, "grid_sample2d second derivative");
|
||||
m.def("grad2_3d", &grid_sample3d_grad2, "grid_sample3d second derivative");
|
||||
}
|
||||
|
668
svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.cu
vendored
Normal file
@ -0,0 +1,668 @@
|
||||
#include <torch/extension.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <ATen/cuda/detail/KernelUtils.h>
|
||||
#include <ATen/native/cuda/KernelUtils.cuh>
|
||||
#include <ATen/native/cuda/GridSampler.cuh>
|
||||
#include <ATen/native/cuda/UpSample.cuh>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/detail/TensorInfo.cuh>
|
||||
#include <ATen/cuda/detail/IndexUtils.cuh>
|
||||
#include <ATen/Dispatch.h>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace at { namespace native {
|
||||
namespace {
|
||||
|
||||
using namespace at::cuda::detail;
|
||||
|
||||
using at::native::detail::GridSamplerInterpolation;
|
||||
using at::native::detail::GridSamplerPadding;
|
||||
|
||||
template <typename scalar_t, typename index_t>
|
||||
C10_LAUNCH_BOUNDS_1(256)
|
||||
__global__ void grid_sampler_2d_grad2_kernel(
|
||||
const index_t nthreads,
|
||||
TensorInfo<scalar_t, index_t> grad2_grad_input,
|
||||
TensorInfo<scalar_t, index_t> grad2_grad_grid,
|
||||
TensorInfo<scalar_t, index_t> grad_output,
|
||||
TensorInfo<scalar_t, index_t> input,
|
||||
TensorInfo<scalar_t, index_t> grid,
|
||||
TensorInfo<scalar_t, index_t> grad_grad_output,
|
||||
TensorInfo<scalar_t, index_t> grad_input,
|
||||
TensorInfo<scalar_t, index_t> grad_grid,
|
||||
const GridSamplerPadding padding_mode,
|
||||
bool align_corners,
|
||||
const index_t grad_input_memory_span) {
|
||||
|
||||
index_t C = input.sizes[1];
|
||||
index_t inp_H = input.sizes[2];
|
||||
index_t inp_W = input.sizes[3];
|
||||
|
||||
index_t out_H = grid.sizes[1];
|
||||
index_t out_W = grid.sizes[2];
|
||||
|
||||
index_t g2inp_sN = grad2_grad_input.strides[0];
|
||||
index_t g2inp_sC = grad2_grad_input.strides[1];
|
||||
index_t g2inp_sH = grad2_grad_input.strides[2];
|
||||
index_t g2inp_sW = grad2_grad_input.strides[3];
|
||||
|
||||
index_t g2grid_sN = grad2_grad_grid.strides[0];
|
||||
index_t g2grid_sH = grad2_grad_grid.strides[1];
|
||||
index_t g2grid_sW = grad2_grad_grid.strides[2];
|
||||
index_t g2grid_sCoor = grad2_grad_grid.strides[3];
|
||||
|
||||
index_t gOut_sN = grad_output.strides[0];
|
||||
index_t gOut_sC = grad_output.strides[1];
|
||||
index_t gOut_sH = grad_output.strides[2];
|
||||
index_t gOut_sW = grad_output.strides[3];
|
||||
|
||||
index_t inp_sN = input.strides[0];
|
||||
index_t inp_sC = input.strides[1];
|
||||
index_t inp_sH = input.strides[2];
|
||||
index_t inp_sW = input.strides[3];
|
||||
|
||||
index_t grid_sN = grid.strides[0];
|
||||
index_t grid_sH = grid.strides[1];
|
||||
index_t grid_sW = grid.strides[2];
|
||||
index_t grid_sCoor = grid.strides[3];
|
||||
|
||||
index_t gInp_sN = grad_input.strides[0];
|
||||
index_t gInp_sC = grad_input.strides[1];
|
||||
index_t gInp_sH = grad_input.strides[2];
|
||||
index_t gInp_sW = grad_input.strides[3];
|
||||
|
||||
index_t gGrid_sW = grad_grid.strides[2];
|
||||
|
||||
index_t ggOut_sN = grad_grad_output.strides[0];
|
||||
index_t ggOut_sC = grad_grad_output.strides[1];
|
||||
index_t ggOut_sH = grad_grad_output.strides[2];
|
||||
index_t ggOut_sW = grad_grad_output.strides[3];
|
||||
|
||||
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
|
||||
const index_t w = index % out_W;
|
||||
const index_t h = (index / out_W) % out_H;
|
||||
const index_t n = index / (out_H * out_W);
|
||||
|
||||
/* Grid related staff */
|
||||
index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
|
||||
|
||||
// get the corresponding input x, y co-ordinates from grid
|
||||
scalar_t x = grid.data[grid_offset];
|
||||
scalar_t y = grid.data[grid_offset + grid_sCoor];
|
||||
|
||||
// multipliers for gradients on ix and iy
|
||||
scalar_t gix_mult, giy_mult;
|
||||
scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult);
|
||||
scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult);
|
||||
|
||||
// get NE, NW, SE, SW pixel values from (x, y)
|
||||
index_t ix_nw = static_cast<index_t>(::floor(ix));
|
||||
index_t iy_nw = static_cast<index_t>(::floor(iy));
|
||||
index_t ix_ne = ix_nw + 1;
|
||||
index_t iy_ne = iy_nw;
|
||||
index_t ix_sw = ix_nw;
|
||||
index_t iy_sw = iy_nw + 1;
|
||||
index_t ix_se = ix_nw + 1;
|
||||
index_t iy_se = iy_nw + 1;
|
||||
|
||||
// get surfaces to each neighbor:
|
||||
scalar_t nw = (ix_se - ix) * (iy_se - iy);
|
||||
scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
scalar_t se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
/* grad2_grad_input related init */
|
||||
scalar_t *g2_inp_ptr_NC = grad2_grad_input.data + n * g2inp_sN;
|
||||
|
||||
/* grad2_grad_grid related init */
|
||||
grid_offset = n * g2grid_sN + h * g2grid_sH + w * g2grid_sW;
|
||||
scalar_t dx = grad2_grad_grid.data[grid_offset];
|
||||
scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor];
|
||||
|
||||
dx = dx * gix_mult;
|
||||
dy = dy * giy_mult;
|
||||
|
||||
/* grad_output related init */
|
||||
scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
|
||||
|
||||
/* input related init */
|
||||
scalar_t *inp_ptr_NC = input.data + n * inp_sN;
|
||||
|
||||
/* grad_grad_output related init */
|
||||
scalar_t *ggOut_ptr_NCHW = grad_grad_output.data + n * ggOut_sN + h * ggOut_sH + w * ggOut_sW;
|
||||
|
||||
/* grad_input related init */
|
||||
index_t NC_offset = n * gInp_sN;
|
||||
|
||||
/* grad_grid related init */
|
||||
scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
|
||||
scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0);
|
||||
|
||||
scalar_t nw_val, ne_val, sw_val, se_val;
|
||||
scalar_t g2_nw_val, g2_ne_val, g2_sw_val, g2_se_val;
|
||||
|
||||
scalar_t zero = static_cast<scalar_t>(0);
|
||||
for (index_t c = 0; c < C;
|
||||
++c,
|
||||
g2_inp_ptr_NC += g2inp_sC,
|
||||
inp_ptr_NC += inp_sC,
|
||||
NC_offset += gInp_sC,
|
||||
gOut_ptr_NCHW += gOut_sC,
|
||||
ggOut_ptr_NCHW += ggOut_sC) {
|
||||
|
||||
nw_val = within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)? inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW]: zero;
|
||||
ne_val = within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)? inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW]: zero;
|
||||
sw_val = within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)? inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW]: zero;
|
||||
se_val = within_bounds_2d(iy_se, ix_se, inp_H, inp_W)? inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW]: zero;
|
||||
|
||||
g2_nw_val = within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)? g2_inp_ptr_NC[iy_nw * g2inp_sH + ix_nw * g2inp_sW]: zero;
|
||||
g2_ne_val = within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)? g2_inp_ptr_NC[iy_ne * g2inp_sH + ix_ne * g2inp_sW]: zero;
|
||||
g2_sw_val = within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)? g2_inp_ptr_NC[iy_sw * g2inp_sH + ix_sw * g2inp_sW]: zero;
|
||||
g2_se_val = within_bounds_2d(iy_se, ix_se, inp_H, inp_W)? g2_inp_ptr_NC[iy_se * g2inp_sH + ix_se * g2inp_sW]: zero;
|
||||
|
||||
// Computing gradient wrt to grad_output = grad2_grad_input * x * y + grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val
|
||||
// grad2_grad_input * x * y
|
||||
*ggOut_ptr_NCHW = static_cast<scalar_t>(0);
|
||||
*ggOut_ptr_NCHW += g2_nw_val * nw + g2_ne_val * ne + g2_sw_val * sw + g2_se_val * se;
|
||||
|
||||
scalar_t nw_tmp = -dx * (iy_se - iy) - dy * (ix_se - ix);
|
||||
scalar_t ne_tmp = +dx * (iy_sw - iy) - dy * (ix - ix_sw);
|
||||
scalar_t sw_tmp = -dx * (iy - iy_ne) + dy * (ix_ne - ix);
|
||||
scalar_t se_tmp = +dx * (iy - iy_nw) + dy * (ix - ix_nw);
|
||||
|
||||
|
||||
// grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val
|
||||
*ggOut_ptr_NCHW += nw_val * nw_tmp + ne_tmp * ne_val + sw_tmp * sw_val + se_tmp * se_val;
|
||||
|
||||
// Computing gradient wrt input = grad2_grad_grid_x * grad_output * y + grad2_grad_grid_y * grad_output * x
|
||||
scalar_t gOut = *gOut_ptr_NCHW;
|
||||
//scalar_t val;
|
||||
//val = gOut * (-dx * (iy_se - iy) - dy * (ix_se - ix));
|
||||
safe_add_2d(grad_input.data, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw_tmp * gOut, NC_offset, grad_input_memory_span);
|
||||
//val = gOut * (+dx * (iy_sw - iy) - dy * (ix - ix_sw));
|
||||
safe_add_2d(grad_input.data, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne_tmp * gOut, NC_offset, grad_input_memory_span);
|
||||
//val = gOut * (-dx * (iy - iy_ne) + dy * (ix_ne - ix));
|
||||
safe_add_2d(grad_input.data, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw_tmp * gOut, NC_offset, grad_input_memory_span);
|
||||
//val = gOut * (+dx * (iy - iy_nw) + dy * (ix - ix_nw));
|
||||
safe_add_2d(grad_input.data, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se_tmp * gOut, NC_offset, grad_input_memory_span);
|
||||
|
||||
scalar_t dxy = nw_val - ne_val - sw_val + se_val;
|
||||
// Computing gradient wrt grid_x = grad2_grad_input * y * gOut + grad2_grad_grid_y * val * gOut
|
||||
gix += gOut * (-g2_nw_val * (iy_se - iy) + g2_ne_val * (iy_sw - iy)
|
||||
-g2_sw_val * (iy - iy_ne) + g2_se_val * (iy - iy_nw));
|
||||
gix += gOut * dy * dxy;
|
||||
|
||||
// Computing gradient wrt grid_y = grad2_grad_input * x * gOut + grad2_grad_grid_x * val * gOut
|
||||
giy += gOut * (-g2_nw_val * (ix_se - ix) - g2_ne_val * (ix - ix_sw)
|
||||
+g2_sw_val * (ix_ne - ix) + g2_se_val * (ix - ix_nw));
|
||||
giy += gOut * dx * dxy;
|
||||
}
|
||||
|
||||
gGrid_ptr_NHW[0] = gix * gix_mult;
|
||||
gGrid_ptr_NHW[1] = giy * giy_mult;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename index_t>
|
||||
C10_LAUNCH_BOUNDS_1(256)
|
||||
__global__ void grid_sampler_3d_grad2_kernel(
|
||||
const index_t nthreads,
|
||||
TensorInfo<scalar_t, index_t> grad2_grad_input,
|
||||
TensorInfo<scalar_t, index_t> grad2_grad_grid,
|
||||
TensorInfo<scalar_t, index_t> grad_output,
|
||||
TensorInfo<scalar_t, index_t> input,
|
||||
TensorInfo<scalar_t, index_t> grid,
|
||||
TensorInfo<scalar_t, index_t> grad_grad_output,
|
||||
TensorInfo<scalar_t, index_t> grad_input,
|
||||
TensorInfo<scalar_t, index_t> grad_grid,
|
||||
const GridSamplerPadding padding_mode,
|
||||
bool align_corners,
|
||||
const index_t grad_input_memory_span) {
|
||||
|
||||
index_t C = input.sizes[1];
|
||||
index_t inp_D = input.sizes[2];
|
||||
index_t inp_H = input.sizes[3];
|
||||
index_t inp_W = input.sizes[4];
|
||||
|
||||
index_t out_D = grid.sizes[1];
|
||||
index_t out_H = grid.sizes[2];
|
||||
index_t out_W = grid.sizes[3];
|
||||
|
||||
index_t g2inp_sN = grad2_grad_input.strides[0];
|
||||
index_t g2inp_sC = grad2_grad_input.strides[1];
|
||||
index_t g2inp_sD = grad2_grad_input.strides[2];
|
||||
index_t g2inp_sH = grad2_grad_input.strides[3];
|
||||
index_t g2inp_sW = grad2_grad_input.strides[4];
|
||||
|
||||
index_t g2grid_sN = grad2_grad_grid.strides[0];
|
||||
index_t g2grid_sD = grad2_grad_grid.strides[1];
|
||||
index_t g2grid_sH = grad2_grad_grid.strides[2];
|
||||
index_t g2grid_sW = grad2_grad_grid.strides[3];
|
||||
index_t g2grid_sCoor = grad2_grad_grid.strides[4];
|
||||
|
||||
index_t gOut_sN = grad_output.strides[0];
|
||||
index_t gOut_sC = grad_output.strides[1];
|
||||
index_t gOut_sD = grad_output.strides[2];
|
||||
index_t gOut_sH = grad_output.strides[3];
|
||||
index_t gOut_sW = grad_output.strides[4];
|
||||
|
||||
index_t inp_sN = input.strides[0];
|
||||
index_t inp_sC = input.strides[1];
|
||||
index_t inp_sD = input.strides[2];
|
||||
index_t inp_sH = input.strides[3];
|
||||
index_t inp_sW = input.strides[4];
|
||||
|
||||
index_t grid_sN = grid.strides[0];
|
||||
index_t grid_sD = grid.strides[1];
|
||||
index_t grid_sH = grid.strides[2];
|
||||
index_t grid_sW = grid.strides[3];
|
||||
index_t grid_sCoor = grid.strides[4];
|
||||
|
||||
index_t gInp_sN = grad_input.strides[0];
|
||||
index_t gInp_sC = grad_input.strides[1];
|
||||
index_t gInp_sD = grad_input.strides[2];
|
||||
index_t gInp_sH = grad_input.strides[3];
|
||||
index_t gInp_sW = grad_input.strides[4];
|
||||
|
||||
index_t gGrid_sW = grad_grid.strides[3];
|
||||
|
||||
index_t ggOut_sN = grad_grad_output.strides[0];
|
||||
index_t ggOut_sC = grad_grad_output.strides[1];
|
||||
index_t ggOut_sD = grad_grad_output.strides[2];
|
||||
index_t ggOut_sH = grad_grad_output.strides[3];
|
||||
index_t ggOut_sW = grad_grad_output.strides[4];
|
||||
|
||||
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
|
||||
const index_t w = index % out_W;
|
||||
const index_t h = (index / out_W) % out_H;
|
||||
const index_t d = (index / (out_H * out_W)) % out_D;
|
||||
const index_t n = index / (out_D * out_H * out_W);
|
||||
|
||||
/* Grid related staff */
|
||||
index_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
|
||||
|
||||
// get the corresponding input x, y co-ordinates from grid
|
||||
scalar_t ix = grid.data[grid_offset];
|
||||
scalar_t iy = grid.data[grid_offset + grid_sCoor];
|
||||
scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor];
|
||||
|
||||
// multipliers for gradients on ix and iy
|
||||
scalar_t gix_mult, giy_mult, giz_mult;
|
||||
ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
|
||||
iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
|
||||
iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult);
|
||||
|
||||
// get NE, NW, SE, SW pixel values from (x, y)
|
||||
index_t ix_tnw = static_cast<index_t>(::floor(ix));
|
||||
index_t iy_tnw = static_cast<index_t>(::floor(iy));
|
||||
index_t iz_tnw = static_cast<index_t>(::floor(iz));
|
||||
|
||||
index_t ix_tne = ix_tnw + 1;
|
||||
index_t iy_tne = iy_tnw;
|
||||
index_t iz_tne = iz_tnw;
|
||||
|
||||
index_t ix_tsw = ix_tnw;
|
||||
index_t iy_tsw = iy_tnw + 1;
|
||||
index_t iz_tsw = iz_tnw;
|
||||
|
||||
index_t ix_tse = ix_tnw + 1;
|
||||
index_t iy_tse = iy_tnw + 1;
|
||||
index_t iz_tse = iz_tnw;
|
||||
|
||||
index_t ix_bnw = ix_tnw;
|
||||
index_t iy_bnw = iy_tnw;
|
||||
index_t iz_bnw = iz_tnw + 1;
|
||||
|
||||
index_t ix_bne = ix_tnw + 1;
|
||||
index_t iy_bne = iy_tnw;
|
||||
index_t iz_bne = iz_tnw + 1;
|
||||
|
||||
index_t ix_bsw = ix_tnw;
|
||||
index_t iy_bsw = iy_tnw + 1;
|
||||
index_t iz_bsw = iz_tnw + 1;
|
||||
|
||||
index_t ix_bse = ix_tnw + 1;
|
||||
index_t iy_bse = iy_tnw + 1;
|
||||
index_t iz_bse = iz_tnw + 1;
|
||||
|
||||
// get surfaces to each neighbor:
|
||||
scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
|
||||
scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
|
||||
scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
|
||||
scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
|
||||
scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
|
||||
scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
|
||||
scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
|
||||
scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
|
||||
|
||||
/* grad2_grad_input related init */
|
||||
scalar_t *g2_inp_ptr_NC = grad2_grad_input.data + n * g2inp_sN;
|
||||
|
||||
/* grad2_grad_grid related init */
|
||||
grid_offset = n * g2grid_sN + d * g2grid_sD + h * g2grid_sH + w * g2grid_sW;
|
||||
scalar_t dx = grad2_grad_grid.data[grid_offset];
|
||||
scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor];
|
||||
scalar_t dz = grad2_grad_grid.data[grid_offset + 2 * g2grid_sCoor];
|
||||
|
||||
dx = dx * gix_mult;
|
||||
dy = dy * giy_mult;
|
||||
dz = dz * giz_mult;
|
||||
|
||||
/* grad_output related init */
|
||||
scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
|
||||
|
||||
/* input related init */
|
||||
scalar_t *inp_ptr_NC = input.data + n * inp_sN;
|
||||
|
||||
/* grad_grad_output related init */
|
||||
scalar_t *ggOut_ptr_NCDHW = grad_grad_output.data + n * ggOut_sN + d * ggOut_sD + h * ggOut_sH + w * ggOut_sW;
|
||||
|
||||
/* grad_input related init */
|
||||
index_t NC_offset = n * gInp_sN;
|
||||
|
||||
/* grad_grid related init */
|
||||
scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW;
|
||||
scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0);
|
||||
|
||||
scalar_t tnw_val, tne_val, tsw_val, tse_val, bnw_val, bne_val, bsw_val, bse_val;
|
||||
scalar_t g2_tnw_val, g2_tne_val, g2_tsw_val, g2_tse_val, g2_bnw_val, g2_bne_val, g2_bsw_val, g2_bse_val;
|
||||
|
||||
scalar_t zero = static_cast<scalar_t>(0);
|
||||
for (index_t c = 0; c < C;
|
||||
++c,
|
||||
g2_inp_ptr_NC += g2inp_sC,
|
||||
inp_ptr_NC += inp_sC,
|
||||
NC_offset += gInp_sC,
|
||||
gOut_ptr_NCDHW += gOut_sC,
|
||||
ggOut_ptr_NCDHW += ggOut_sC) {
|
||||
|
||||
if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
|
||||
tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
|
||||
g2_tnw_val = g2_inp_ptr_NC[iz_tnw * g2inp_sD + iy_tnw * g2inp_sH + ix_tnw * g2inp_sW];
|
||||
} else {
|
||||
tnw_val = zero;
|
||||
g2_tnw_val = zero;
|
||||
}
|
||||
if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
|
||||
tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
|
||||
g2_tne_val = g2_inp_ptr_NC[iz_tne * g2inp_sD + iy_tne * g2inp_sH + ix_tne * g2inp_sW];
|
||||
} else {
|
||||
tne_val = zero;
|
||||
g2_tne_val = zero;
|
||||
}
|
||||
if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
|
||||
tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
|
||||
g2_tsw_val = g2_inp_ptr_NC[iz_tsw * g2inp_sD + iy_tsw * g2inp_sH + ix_tsw * g2inp_sW];
|
||||
} else {
|
||||
tsw_val = zero;
|
||||
g2_tsw_val = zero;
|
||||
}
|
||||
if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
|
||||
tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
|
||||
g2_tse_val = g2_inp_ptr_NC[iz_tse * g2inp_sD + iy_tse * g2inp_sH + ix_tse * g2inp_sW];
|
||||
} else {
|
||||
tse_val = zero;
|
||||
g2_tse_val = zero;
|
||||
}
|
||||
|
||||
if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
|
||||
bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
|
||||
g2_bnw_val = g2_inp_ptr_NC[iz_bnw * g2inp_sD + iy_bnw * g2inp_sH + ix_bnw * g2inp_sW];
|
||||
} else {
|
||||
bnw_val = zero;
|
||||
g2_bnw_val = zero;
|
||||
}
|
||||
if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
|
||||
bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
|
||||
g2_bne_val = g2_inp_ptr_NC[iz_bne * g2inp_sD + iy_bne * g2inp_sH + ix_bne * g2inp_sW];
|
||||
} else {
|
||||
bne_val = zero;
|
||||
g2_bne_val = zero;
|
||||
}
|
||||
if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
|
||||
bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
|
||||
g2_bsw_val = g2_inp_ptr_NC[iz_bsw * g2inp_sD + iy_bsw * g2inp_sH + ix_bsw * g2inp_sW];
|
||||
} else {
|
||||
bsw_val = zero;
|
||||
g2_bsw_val = zero;
|
||||
}
|
||||
if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
|
||||
bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
|
||||
g2_bse_val = g2_inp_ptr_NC[iz_bse * g2inp_sD + iy_bse * g2inp_sH + ix_bse * g2inp_sW];
|
||||
} else {
|
||||
bse_val = zero;
|
||||
g2_bse_val = zero;
|
||||
}
|
||||
|
||||
// Computing gradient wrt to grad_output =
|
||||
// grad2_grad_input * x * y * z
|
||||
*ggOut_ptr_NCDHW = static_cast<scalar_t>(0);
|
||||
*ggOut_ptr_NCDHW += g2_tnw_val * tnw + g2_tne_val * tne + g2_tsw_val * tsw + g2_tse_val * tse
|
||||
+g2_bnw_val * bnw + g2_bne_val * bne + g2_bsw_val * bsw + g2_bse_val * bse;
|
||||
|
||||
// +val * (grad2_grad_grid_x * y * z + grad2_grad_grid_y * x * z + grad2_grad_grid_z * x * y)
|
||||
scalar_t tnw_tmp = (-dx * (iy_bse - iy) * (iz_bse - iz) - dy * (ix_bse - ix) * (iz_bse - iz) - dz * (ix_bse - ix) * (iy_bse - iy));
|
||||
scalar_t tne_tmp = (+dx * (iy_bsw - iy) * (iz_bsw - iz) - dy * (ix - ix_bsw) * (iz_bsw - iz) - dz * (ix - ix_bsw) * (iy_bsw - iy));
|
||||
scalar_t tsw_tmp = (-dx * (iy - iy_bne) * (iz_bne - iz) + dy * (ix_bne - ix) * (iz_bne - iz) - dz * (ix_bne - ix) * (iy - iy_bne));
|
||||
scalar_t tse_tmp = (+dx * (iy - iy_bnw) * (iz_bnw - iz) + dy * (ix - ix_bnw) * (iz_bnw - iz) - dz * (ix - ix_bnw) * (iy - iy_bnw));
|
||||
scalar_t bnw_tmp = (-dx * (iy_tse - iy) * (iz - iz_tse) - dy * (ix_tse - ix) * (iz - iz_tse) + dz * (ix_tse - ix) * (iy_tse - iy));
|
||||
scalar_t bne_tmp = (+dx * (iy_tsw - iy) * (iz - iz_tsw) - dy * (ix - ix_tsw) * (iz - iz_tsw) + dz * (ix - ix_tsw) * (iy_tsw - iy));
|
||||
scalar_t bsw_tmp = (-dx * (iy - iy_tne) * (iz - iz_tne) + dy * (ix_tne - ix) * (iz - iz_tne) + dz * (ix_tne - ix) * (iy - iy_tne));
|
||||
scalar_t bse_tmp = (+dx * (iy - iy_tnw) * (iz - iz_tnw) + dy * (ix - ix_tnw) * (iz - iz_tnw) + dz * (ix - ix_tnw) * (iy - iy_tnw));
|
||||
|
||||
*ggOut_ptr_NCDHW += tnw_val * tnw_tmp + tne_val * tne_tmp + tsw_val * tsw_tmp + tse_val * tse_tmp
|
||||
+bnw_val * bnw_tmp + bne_val * bne_tmp + bsw_val * bsw_tmp + bse_val * bse_tmp;
|
||||
|
||||
// Computing gradient wrt input = grad2_grad_grid_x * grad_output * y * z + grad2_grad_grid_y * grad_output * x * z +
|
||||
// grad2_grad_grid_z * grad_output * y * z
|
||||
scalar_t gOut = *gOut_ptr_NCDHW;
|
||||
|
||||
safe_add_3d(grad_input.data, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
safe_add_3d(grad_input.data, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
safe_add_3d(grad_input.data, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
safe_add_3d(grad_input.data, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
safe_add_3d(grad_input.data, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
safe_add_3d(grad_input.data, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
safe_add_3d(grad_input.data, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
safe_add_3d(grad_input.data, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
|
||||
//Computing gradient wrt grid
|
||||
scalar_t dxy = (tnw_val * (iz_bse - iz) - tne_val * (iz_bsw - iz)
|
||||
-tsw_val * (iz_bne - iz) + tse_val * (iz_bnw - iz)
|
||||
+bnw_val * (iz - iz_tse) - bne_val * (iz - iz_tsw)
|
||||
-bsw_val * (iz - iz_tne) + bse_val * (iz - iz_tnw));
|
||||
|
||||
scalar_t dxz = (tnw_val * (iy_bse - iy) - tne_val * (iy_bsw - iy)
|
||||
+tsw_val * (iy - iy_bne) - tse_val * (iy - iy_bnw)
|
||||
-bnw_val * (iy_tse - iy) + bne_val * (iy_tsw - iy)
|
||||
-bsw_val * (iy - iy_tne) + bse_val * (iy - iy_tnw));
|
||||
|
||||
scalar_t dyz = (tnw_val * (ix_bse - ix) + tne_val * (ix - ix_bsw)
|
||||
-tsw_val * (ix_bne - ix) - tse_val * (ix - ix_bnw)
|
||||
-bnw_val * (ix_tse - ix) - bne_val * (ix - ix_tsw)
|
||||
+bsw_val * (ix_tne - ix) + bse_val * (ix - ix_tnw));
|
||||
|
||||
|
||||
// Computing gradient wrt grid_x =
|
||||
// grad2_grad_input * z * y * gOut
|
||||
gix += gOut * (-g2_tnw_val * (iy_bse - iy) * (iz_bse - iz) + g2_tne_val * (iy_bsw - iy) * (iz_bsw - iz)
|
||||
-g2_tsw_val * (iy - iy_bne) * (iz_bne - iz) + g2_tse_val * (iy - iy_bnw) * (iz_bnw - iz)
|
||||
-g2_bnw_val * (iy_tse - iy) * (iz - iz_tse) + g2_bne_val * (iy_tsw - iy) * (iz - iz_tsw)
|
||||
-g2_bsw_val * (iy - iy_tne) * (iz - iz_tne) + g2_bse_val * (iy - iy_tnw) * (iz - iz_tnw));
|
||||
|
||||
//+ grad2_grad_grid_z * y * val * gOut + grad2_grad_grid_y * z * val * gOut
|
||||
gix += gOut * (dz * dxz + dy * dxy);
|
||||
|
||||
// Computing gradient wrt grid_y =
|
||||
// grad2_grad_input * x * z * gOut
|
||||
giy += gOut * (-g2_tnw_val * (ix_bse - ix) * (iz_bse - iz) - g2_tne_val * (ix - ix_bsw) * (iz_bsw - iz)
|
||||
+g2_tsw_val * (ix_bne - ix) * (iz_bne - iz) + g2_tse_val * (ix - ix_bnw) * (iz_bnw - iz)
|
||||
-g2_bnw_val * (ix_tse - ix) * (iz - iz_tse) - g2_bne_val * (ix - ix_tsw) * (iz - iz_tsw)
|
||||
+g2_bsw_val * (ix_tne - ix) * (iz - iz_tne) + g2_bse_val * (ix - ix_tnw) * (iz - iz_tnw));
|
||||
//+ grad2_grad_grid_x * z * val * gOut + grad2_grad_grid_z * x * val * gOut
|
||||
giy += gOut * (dx * dxy + dz * dyz);
|
||||
|
||||
// Computing gradient wrt grid_z =
|
||||
// grad2_grad_input * x * y * gOut
|
||||
giz += gOut * (-g2_tnw_val * (ix_bse - ix) * (iy_bse - iy) - g2_tne_val * (ix - ix_bsw) * (iy_bsw - iy)
|
||||
-g2_tsw_val * (ix_bne - ix) * (iy - iy_bne) - g2_tse_val * (ix - ix_bnw) * (iy - iy_bnw)
|
||||
+g2_bnw_val * (ix_tse - ix) * (iy_tse - iy) + g2_bne_val * (ix - ix_tsw) * (iy_tsw - iy)
|
||||
+g2_bsw_val * (ix_tne - ix) * (iy - iy_tne) + g2_bse_val * (ix - ix_tnw) * (iy - iy_tnw));
|
||||
//+ grad2_grad_grid_x * y * val * gOut + grad2_grad_grid_y * x * val * gOut
|
||||
giz += gOut * (dx * dxz + dy * dyz);
|
||||
}
|
||||
|
||||
gGrid_ptr_NDHW[0] = gix * gix_mult;
|
||||
gGrid_ptr_NDHW[1] = giy * giy_mult;
|
||||
gGrid_ptr_NDHW[2] = giz * giz_mult;
|
||||
}
|
||||
}}
|
||||
|
||||
|
||||
std::vector<torch::Tensor> grid_sample2d_cuda_grad2(
|
||||
const torch::Tensor &grad2_grad_input,
|
||||
const torch::Tensor &grad2_grad_grid,
|
||||
const torch::Tensor &grad_output,
|
||||
const torch::Tensor &input,
|
||||
const torch::Tensor &grid,
|
||||
bool padding_mode,
|
||||
bool align_corners) {
|
||||
|
||||
const auto batch_size = input.size(0);
|
||||
const auto C = input.size(1);
|
||||
const auto H_IN = input.size(2);
|
||||
const auto W_IN = input.size(3);
|
||||
|
||||
const auto H_OUT = grid.size(1);
|
||||
const auto W_OUT = grid.size(2);
|
||||
|
||||
torch::Tensor grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
torch::Tensor grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
torch::Tensor grad_grid = torch::zeros_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
|
||||
int64_t count = batch_size * H_OUT * W_OUT;
|
||||
|
||||
if (count > 0) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_grad2_cuda", [&] {
|
||||
if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
|
||||
canUse32BitIndexMath(grad_output)) {
|
||||
grid_sampler_2d_grad2_kernel<scalar_t>
|
||||
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
static_cast<int>(count),
|
||||
getTensorInfo<scalar_t, int>(grad2_grad_input),
|
||||
getTensorInfo<scalar_t, int>(grad2_grad_grid),
|
||||
getTensorInfo<scalar_t, int>(grad_output),
|
||||
getTensorInfo<scalar_t, int>(input),
|
||||
getTensorInfo<scalar_t, int>(grid),
|
||||
getTensorInfo<scalar_t, int>(grad_grad_output),
|
||||
getTensorInfo<scalar_t, int>(grad_input),
|
||||
getTensorInfo<scalar_t, int>(grad_grid),
|
||||
static_cast<GridSamplerPadding>(padding_mode),
|
||||
align_corners,
|
||||
static_cast<int>(grad_input.numel()));
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
} else {
|
||||
grid_sampler_2d_grad2_kernel<scalar_t>
|
||||
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
count,
|
||||
getTensorInfo<scalar_t, int64_t>(grad2_grad_input),
|
||||
getTensorInfo<scalar_t, int64_t>(grad2_grad_grid),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_output),
|
||||
getTensorInfo<scalar_t, int64_t>(input),
|
||||
getTensorInfo<scalar_t, int64_t>(grid),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_grad_output),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_input),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_grid),
|
||||
static_cast<GridSamplerPadding>(padding_mode),
|
||||
align_corners,
|
||||
grad_input.numel());
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return {grad_grad_output, grad_input, grad_grid};
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> grid_sample3d_cuda_grad2(
|
||||
const torch::Tensor &grad2_grad_input,
|
||||
const torch::Tensor &grad2_grad_grid,
|
||||
const torch::Tensor &grad_output,
|
||||
const torch::Tensor &input,
|
||||
const torch::Tensor &grid,
|
||||
bool padding_mode,
|
||||
bool align_corners) {
|
||||
|
||||
const auto batch_size = input.size(0);
|
||||
const auto C = input.size(1);
|
||||
const auto D_IN = input.size(2);
|
||||
const auto H_IN = input.size(3);
|
||||
const auto W_IN = input.size(4);
|
||||
|
||||
const auto D_OUT = grid.size(1);
|
||||
const auto H_OUT = grid.size(2);
|
||||
const auto W_OUT = grid.size(3);
|
||||
|
||||
torch::Tensor grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
torch::Tensor grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
torch::Tensor grad_grid = torch::zeros_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
|
||||
int64_t count = batch_size * D_OUT * H_OUT * W_OUT;
|
||||
|
||||
if (count > 0) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_grad2_cuda", [&] {
|
||||
if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
|
||||
canUse32BitIndexMath(grad_output)) {
|
||||
grid_sampler_3d_grad2_kernel<scalar_t>
|
||||
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
static_cast<int>(count),
|
||||
getTensorInfo<scalar_t, int>(grad2_grad_input),
|
||||
getTensorInfo<scalar_t, int>(grad2_grad_grid),
|
||||
getTensorInfo<scalar_t, int>(grad_output),
|
||||
getTensorInfo<scalar_t, int>(input),
|
||||
getTensorInfo<scalar_t, int>(grid),
|
||||
getTensorInfo<scalar_t, int>(grad_grad_output),
|
||||
getTensorInfo<scalar_t, int>(grad_input),
|
||||
getTensorInfo<scalar_t, int>(grad_grid),
|
||||
static_cast<GridSamplerPadding>(padding_mode),
|
||||
align_corners,
|
||||
static_cast<int>(grad_input.numel()));
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
} else {
|
||||
grid_sampler_3d_grad2_kernel<scalar_t>
|
||||
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
count,
|
||||
getTensorInfo<scalar_t, int64_t>(grad2_grad_input),
|
||||
getTensorInfo<scalar_t, int64_t>(grad2_grad_grid),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_output),
|
||||
getTensorInfo<scalar_t, int64_t>(input),
|
||||
getTensorInfo<scalar_t, int64_t>(grid),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_grad_output),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_input),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_grid),
|
||||
static_cast<GridSamplerPadding>(padding_mode),
|
||||
align_corners,
|
||||
grad_input.numel());
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return {grad_grad_output, grad_input, grad_grid};
|
||||
}
|
||||
|
||||
}}
|
145
svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample.py
vendored
Normal file
@ -0,0 +1,145 @@
|
||||
from torch.utils.cpp_extension import load
|
||||
import torch
|
||||
from pkg_resources import parse_version
|
||||
|
||||
from .. import custom_ops
|
||||
import os
|
||||
|
||||
# gridsample_grad2 = load(name='gridsample_grad2', sources=['third_party/ops/gridsample_cuda.cpp', 'third_party/ops/gridsample_cuda.cu'], verbose=True)
|
||||
|
||||
gridsample_grad2 = load(name='gridsample_grad2', sources=[os.path.join(os.path.dirname(__file__), f) for f in ['grid_sample.cpp', 'grid_sample.cu']], verbose=True)
|
||||
|
||||
gridsample_grad2 = None
|
||||
|
||||
def _init():
|
||||
global gridsample_grad2
|
||||
if gridsample_grad2 is None:
|
||||
gridsample_grad2 = custom_ops.get_plugin(
|
||||
module_name='gridsample_grad2',
|
||||
sources=['gridsample_cuda.cpp', 'gridsample_cuda.cu'],
|
||||
headers=None,
|
||||
source_dir=os.path.dirname(__file__),
|
||||
extra_cuda_cflags=['--use_fast_math'],
|
||||
)
|
||||
return True
|
||||
|
||||
_init()
|
||||
|
||||
def grid_sample_2d(input, grid, padding_mode='zeros', align_corners=True):
|
||||
assert padding_mode in ['zeros', 'border']
|
||||
return _GridSample2dForward.apply(input, grid, padding_mode, align_corners)
|
||||
|
||||
|
||||
def grid_sample_3d(input, grid, padding_mode='zeros', align_corners=True):
|
||||
assert padding_mode in ['zeros', 'border']
|
||||
return _GridSample3dForward.apply(input, grid, padding_mode, align_corners)
|
||||
|
||||
|
||||
_use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a')
|
||||
|
||||
|
||||
class _GridSample2dForward(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, grid, padding_mode=0, align_corners=True):
|
||||
assert input.ndim == 4
|
||||
assert grid.ndim == 4
|
||||
assert input.shape[0] == grid.shape[0]
|
||||
assert grid.shape[3] == 2
|
||||
|
||||
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear',
|
||||
padding_mode=padding_mode, align_corners=align_corners)
|
||||
ctx.save_for_backward(input, grid)
|
||||
ctx.padding_mode = ['zeros', 'border'].index(padding_mode)
|
||||
ctx.align_corners = align_corners
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, grid = ctx.saved_tensors
|
||||
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid, ctx.padding_mode, ctx.align_corners)
|
||||
return grad_input, grad_grid, None, None
|
||||
|
||||
class _GridSample2dBackward(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, grad_output, input, grid, padding_mode=0, align_corners=True):
|
||||
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')[0]
|
||||
if _use_pytorch_1_11_api:
|
||||
output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
|
||||
# breakpoint()
|
||||
grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners, output_mask)
|
||||
else:
|
||||
grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners)
|
||||
|
||||
ctx.save_for_backward(grad_output, input, grid)
|
||||
ctx.padding_mode = padding_mode
|
||||
ctx.align_corners = align_corners
|
||||
|
||||
return grad_input, grad_grid
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad2_grad_input, grad2_grad_grid):
|
||||
grad_output, input, grid = ctx.saved_tensors
|
||||
assert grad_output.is_cuda and input.is_cuda and grid.is_cuda and grad2_grad_input.is_cuda and grad2_grad_grid.is_cuda
|
||||
out = gridsample_grad2.grad2_2d(grad2_grad_input, grad2_grad_grid, grad_output,
|
||||
input, grid, ctx.padding_mode, ctx.align_corners)
|
||||
|
||||
grad_grad_output = out[0]
|
||||
grad_input = out[1]
|
||||
grad_grid = out[2]
|
||||
|
||||
return grad_grad_output, grad_input, grad_grid, None, None
|
||||
|
||||
|
||||
class _GridSample3dForward(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, grid, padding_mode=0, align_corners=True):
|
||||
assert input.ndim == 5
|
||||
assert grid.ndim == 5
|
||||
assert input.shape[0] == grid.shape[0]
|
||||
assert grid.shape[4] == 3
|
||||
|
||||
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear',
|
||||
padding_mode=padding_mode, align_corners=align_corners)
|
||||
ctx.save_for_backward(input, grid)
|
||||
ctx.padding_mode = ['zeros', 'border'].index(padding_mode)
|
||||
ctx.align_corners = align_corners
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, grid = ctx.saved_tensors
|
||||
grad_input, grad_grid = _GridSample3dBackward.apply(grad_output, input, grid, ctx.padding_mode, ctx.align_corners)
|
||||
return grad_input, grad_grid, None, None
|
||||
|
||||
|
||||
|
||||
class _GridSample3dBackward(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, grad_output, input, grid, padding_mode=0, align_corners=True):
|
||||
op = torch._C._jit_get_operation('aten::grid_sampler_3d_backward')
|
||||
if _use_pytorch_1_11_api:
|
||||
output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
|
||||
grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners, output_mask)
|
||||
else:
|
||||
grad_input, grad_grid = op(grad_output, input, grid, 0, padding_mode, align_corners)
|
||||
|
||||
ctx.save_for_backward(grad_output, input, grid)
|
||||
ctx.padding_mode = padding_mode
|
||||
ctx.align_corners = align_corners
|
||||
|
||||
return grad_input, grad_grid
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad2_grad_input, grad2_grad_grid):
|
||||
grad_output, input, grid = ctx.saved_tensors
|
||||
assert grad_output.is_cuda and input.is_cuda and grid.is_cuda and grad2_grad_input.is_cuda and grad2_grad_grid.is_cuda
|
||||
out = gridsample_grad2.grad2_3d(grad2_grad_input, grad2_grad_grid, grad_output,
|
||||
input, grid, ctx.padding_mode, ctx.align_corners)
|
||||
|
||||
grad_grad_output = out[0]
|
||||
grad_input = out[1]
|
||||
grad_grid = out[2]
|
||||
|
||||
return grad_grad_output, grad_input, grad_grid, None, None
|
||||
|
||||
|
79
svrm/ldm/modules/rendering_neus/third_party/ops/grid_sample_gradfix.py
vendored
Normal file
@ -0,0 +1,79 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
|
||||
"""Custom replacement for `torch.nn.functional.grid_sample` that
|
||||
supports arbitrarily high order gradients between the input and output.
|
||||
Only works on 2D images and assumes
|
||||
`mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`."""
|
||||
|
||||
import torch
|
||||
|
||||
# pylint: disable=redefined-builtin
|
||||
# pylint: disable=arguments-differ
|
||||
# pylint: disable=protected-access
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
enabled = True # Enable the custom op by setting this to true.
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def grid_sample(input, grid):
|
||||
if _should_use_custom_op():
|
||||
return _GridSample2dForward.apply(input, grid)
|
||||
return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
def _should_use_custom_op():
|
||||
return enabled
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class _GridSample2dForward(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, input, grid):
|
||||
assert input.ndim == 4
|
||||
assert grid.ndim == 4
|
||||
output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
||||
ctx.save_for_backward(input, grid)
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_output):
|
||||
input, grid = ctx.saved_tensors
|
||||
grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
|
||||
return grad_input, grad_grid
|
||||
|
||||
#----------------------------------------------------------------------------
|
||||
|
||||
class _GridSample2dBackward(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, grad_output, input, grid):
|
||||
op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
|
||||
grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
|
||||
ctx.save_for_backward(grid)
|
||||
return grad_input, grad_grid
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad2_grad_input, grad2_grad_grid):
|
||||
_ = grad2_grad_grid # unused
|
||||
grid, = ctx.saved_tensors
|
||||
grad2_grad_output = None
|
||||
grad2_input = None # grad2_grad_input #
|
||||
grad2_grid = None # grad2_grad_grid #
|
||||
|
||||
if ctx.needs_input_grad[0]:
|
||||
grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
|
||||
|
||||
assert not ctx.needs_input_grad[2]
|
||||
return grad2_grad_output, grad2_input, grad2_grid
|
||||
|
||||
#----------------------------------------------------------------------------
|
57
svrm/ldm/modules/rendering_neus/third_party/ops/gridsample_cuda.cpp
vendored
Normal file
@ -0,0 +1,57 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
// CUDA forward declarations
|
||||
|
||||
namespace at {namespace native {
|
||||
std::vector<torch::Tensor> grid_sample2d_cuda_grad2(
|
||||
const torch::Tensor &grad2_grad_input,
|
||||
const torch::Tensor &grad2_grad_grid,
|
||||
const torch::Tensor &grad_output,
|
||||
const torch::Tensor &input,
|
||||
const torch::Tensor &grid,
|
||||
bool padding_mode,
|
||||
bool align_corners);
|
||||
std::vector<torch::Tensor> grid_sample3d_cuda_grad2(
|
||||
const torch::Tensor &grad2_grad_input,
|
||||
const torch::Tensor &grad2_grad_grid,
|
||||
const torch::Tensor &grad_output,
|
||||
const torch::Tensor &input,
|
||||
const torch::Tensor &grid,
|
||||
bool padding_mode,
|
||||
bool align_corners);
|
||||
}}
|
||||
|
||||
std::vector<torch::Tensor> grid_sample2d_grad2(
|
||||
const torch::Tensor &grad2_grad_input,
|
||||
const torch::Tensor &grad2_grad_grid,
|
||||
const torch::Tensor &grad_output,
|
||||
const torch::Tensor &input,
|
||||
const torch::Tensor &grid,
|
||||
bool padding_mode,
|
||||
bool align_corners) {
|
||||
|
||||
return at::native::grid_sample2d_cuda_grad2(grad2_grad_input, grad2_grad_grid,
|
||||
grad_output, input, grid, padding_mode, align_corners);
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> grid_sample3d_grad2(
|
||||
const torch::Tensor &grad2_grad_input,
|
||||
const torch::Tensor &grad2_grad_grid,
|
||||
const torch::Tensor &grad_output,
|
||||
const torch::Tensor &input,
|
||||
const torch::Tensor &grid,
|
||||
bool padding_mode,
|
||||
bool align_corners) {
|
||||
|
||||
return at::native::grid_sample3d_cuda_grad2(grad2_grad_input, grad2_grad_grid,
|
||||
grad_output, input, grid, padding_mode, align_corners);
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("grad2_2d", &grid_sample2d_grad2, "grid_sample2d second derivative");
|
||||
m.def("grad2_3d", &grid_sample3d_grad2, "grid_sample3d second derivative");
|
||||
}
|
||||
|
668
svrm/ldm/modules/rendering_neus/third_party/ops/gridsample_cuda.cu
vendored
Normal file
@ -0,0 +1,668 @@
|
||||
#include <torch/extension.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <ATen/cuda/detail/KernelUtils.h>
|
||||
#include <ATen/native/cuda/KernelUtils.cuh>
|
||||
#include <ATen/native/cuda/GridSampler.cuh>
|
||||
#include <ATen/native/cuda/UpSample.cuh>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/detail/TensorInfo.cuh>
|
||||
#include <ATen/cuda/detail/IndexUtils.cuh>
|
||||
#include <ATen/Dispatch.h>
|
||||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace at { namespace native {
|
||||
namespace {
|
||||
|
||||
using namespace at::cuda::detail;
|
||||
|
||||
using at::native::detail::GridSamplerInterpolation;
|
||||
using at::native::detail::GridSamplerPadding;
|
||||
|
||||
template <typename scalar_t, typename index_t>
|
||||
C10_LAUNCH_BOUNDS_1(256)
|
||||
__global__ void grid_sampler_2d_grad2_kernel(
|
||||
const index_t nthreads,
|
||||
TensorInfo<scalar_t, index_t> grad2_grad_input,
|
||||
TensorInfo<scalar_t, index_t> grad2_grad_grid,
|
||||
TensorInfo<scalar_t, index_t> grad_output,
|
||||
TensorInfo<scalar_t, index_t> input,
|
||||
TensorInfo<scalar_t, index_t> grid,
|
||||
TensorInfo<scalar_t, index_t> grad_grad_output,
|
||||
TensorInfo<scalar_t, index_t> grad_input,
|
||||
TensorInfo<scalar_t, index_t> grad_grid,
|
||||
const GridSamplerPadding padding_mode,
|
||||
bool align_corners,
|
||||
const index_t grad_input_memory_span) {
|
||||
|
||||
index_t C = input.sizes[1];
|
||||
index_t inp_H = input.sizes[2];
|
||||
index_t inp_W = input.sizes[3];
|
||||
|
||||
index_t out_H = grid.sizes[1];
|
||||
index_t out_W = grid.sizes[2];
|
||||
|
||||
index_t g2inp_sN = grad2_grad_input.strides[0];
|
||||
index_t g2inp_sC = grad2_grad_input.strides[1];
|
||||
index_t g2inp_sH = grad2_grad_input.strides[2];
|
||||
index_t g2inp_sW = grad2_grad_input.strides[3];
|
||||
|
||||
index_t g2grid_sN = grad2_grad_grid.strides[0];
|
||||
index_t g2grid_sH = grad2_grad_grid.strides[1];
|
||||
index_t g2grid_sW = grad2_grad_grid.strides[2];
|
||||
index_t g2grid_sCoor = grad2_grad_grid.strides[3];
|
||||
|
||||
index_t gOut_sN = grad_output.strides[0];
|
||||
index_t gOut_sC = grad_output.strides[1];
|
||||
index_t gOut_sH = grad_output.strides[2];
|
||||
index_t gOut_sW = grad_output.strides[3];
|
||||
|
||||
index_t inp_sN = input.strides[0];
|
||||
index_t inp_sC = input.strides[1];
|
||||
index_t inp_sH = input.strides[2];
|
||||
index_t inp_sW = input.strides[3];
|
||||
|
||||
index_t grid_sN = grid.strides[0];
|
||||
index_t grid_sH = grid.strides[1];
|
||||
index_t grid_sW = grid.strides[2];
|
||||
index_t grid_sCoor = grid.strides[3];
|
||||
|
||||
index_t gInp_sN = grad_input.strides[0];
|
||||
index_t gInp_sC = grad_input.strides[1];
|
||||
index_t gInp_sH = grad_input.strides[2];
|
||||
index_t gInp_sW = grad_input.strides[3];
|
||||
|
||||
index_t gGrid_sW = grad_grid.strides[2];
|
||||
|
||||
index_t ggOut_sN = grad_grad_output.strides[0];
|
||||
index_t ggOut_sC = grad_grad_output.strides[1];
|
||||
index_t ggOut_sH = grad_grad_output.strides[2];
|
||||
index_t ggOut_sW = grad_grad_output.strides[3];
|
||||
|
||||
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
|
||||
const index_t w = index % out_W;
|
||||
const index_t h = (index / out_W) % out_H;
|
||||
const index_t n = index / (out_H * out_W);
|
||||
|
||||
/* Grid related staff */
|
||||
index_t grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
|
||||
|
||||
// get the corresponding input x, y co-ordinates from grid
|
||||
scalar_t x = grid.data[grid_offset];
|
||||
scalar_t y = grid.data[grid_offset + grid_sCoor];
|
||||
|
||||
// multipliers for gradients on ix and iy
|
||||
scalar_t gix_mult, giy_mult;
|
||||
scalar_t ix = grid_sampler_compute_source_index_set_grad(x, inp_W, padding_mode, align_corners, &gix_mult);
|
||||
scalar_t iy = grid_sampler_compute_source_index_set_grad(y, inp_H, padding_mode, align_corners, &giy_mult);
|
||||
|
||||
// get NE, NW, SE, SW pixel values from (x, y)
|
||||
index_t ix_nw = static_cast<index_t>(::floor(ix));
|
||||
index_t iy_nw = static_cast<index_t>(::floor(iy));
|
||||
index_t ix_ne = ix_nw + 1;
|
||||
index_t iy_ne = iy_nw;
|
||||
index_t ix_sw = ix_nw;
|
||||
index_t iy_sw = iy_nw + 1;
|
||||
index_t ix_se = ix_nw + 1;
|
||||
index_t iy_se = iy_nw + 1;
|
||||
|
||||
// get surfaces to each neighbor:
|
||||
scalar_t nw = (ix_se - ix) * (iy_se - iy);
|
||||
scalar_t ne = (ix - ix_sw) * (iy_sw - iy);
|
||||
scalar_t sw = (ix_ne - ix) * (iy - iy_ne);
|
||||
scalar_t se = (ix - ix_nw) * (iy - iy_nw);
|
||||
|
||||
/* grad2_grad_input related init */
|
||||
scalar_t *g2_inp_ptr_NC = grad2_grad_input.data + n * g2inp_sN;
|
||||
|
||||
/* grad2_grad_grid related init */
|
||||
grid_offset = n * g2grid_sN + h * g2grid_sH + w * g2grid_sW;
|
||||
scalar_t dx = grad2_grad_grid.data[grid_offset];
|
||||
scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor];
|
||||
|
||||
dx = dx * gix_mult;
|
||||
dy = dy * giy_mult;
|
||||
|
||||
/* grad_output related init */
|
||||
scalar_t *gOut_ptr_NCHW = grad_output.data + n * gOut_sN + h * gOut_sH + w * gOut_sW;
|
||||
|
||||
/* input related init */
|
||||
scalar_t *inp_ptr_NC = input.data + n * inp_sN;
|
||||
|
||||
/* grad_grad_output related init */
|
||||
scalar_t *ggOut_ptr_NCHW = grad_grad_output.data + n * ggOut_sN + h * ggOut_sH + w * ggOut_sW;
|
||||
|
||||
/* grad_input related init */
|
||||
index_t NC_offset = n * gInp_sN;
|
||||
|
||||
/* grad_grid related init */
|
||||
scalar_t *gGrid_ptr_NHW = grad_grid.data + index * gGrid_sW;
|
||||
scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0);
|
||||
|
||||
scalar_t nw_val, ne_val, sw_val, se_val;
|
||||
scalar_t g2_nw_val, g2_ne_val, g2_sw_val, g2_se_val;
|
||||
|
||||
scalar_t zero = static_cast<scalar_t>(0);
|
||||
for (index_t c = 0; c < C;
|
||||
++c,
|
||||
g2_inp_ptr_NC += g2inp_sC,
|
||||
inp_ptr_NC += inp_sC,
|
||||
NC_offset += gInp_sC,
|
||||
gOut_ptr_NCHW += gOut_sC,
|
||||
ggOut_ptr_NCHW += ggOut_sC) {
|
||||
|
||||
nw_val = within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)? inp_ptr_NC[iy_nw * inp_sH + ix_nw * inp_sW]: zero;
|
||||
ne_val = within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)? inp_ptr_NC[iy_ne * inp_sH + ix_ne * inp_sW]: zero;
|
||||
sw_val = within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)? inp_ptr_NC[iy_sw * inp_sH + ix_sw * inp_sW]: zero;
|
||||
se_val = within_bounds_2d(iy_se, ix_se, inp_H, inp_W)? inp_ptr_NC[iy_se * inp_sH + ix_se * inp_sW]: zero;
|
||||
|
||||
g2_nw_val = within_bounds_2d(iy_nw, ix_nw, inp_H, inp_W)? g2_inp_ptr_NC[iy_nw * g2inp_sH + ix_nw * g2inp_sW]: zero;
|
||||
g2_ne_val = within_bounds_2d(iy_ne, ix_ne, inp_H, inp_W)? g2_inp_ptr_NC[iy_ne * g2inp_sH + ix_ne * g2inp_sW]: zero;
|
||||
g2_sw_val = within_bounds_2d(iy_sw, ix_sw, inp_H, inp_W)? g2_inp_ptr_NC[iy_sw * g2inp_sH + ix_sw * g2inp_sW]: zero;
|
||||
g2_se_val = within_bounds_2d(iy_se, ix_se, inp_H, inp_W)? g2_inp_ptr_NC[iy_se * g2inp_sH + ix_se * g2inp_sW]: zero;
|
||||
|
||||
// Computing gradient wrt to grad_output = grad2_grad_input * x * y + grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val
|
||||
// grad2_grad_input * x * y
|
||||
*ggOut_ptr_NCHW = static_cast<scalar_t>(0);
|
||||
*ggOut_ptr_NCHW += g2_nw_val * nw + g2_ne_val * ne + g2_sw_val * sw + g2_se_val * se;
|
||||
|
||||
scalar_t nw_tmp = -dx * (iy_se - iy) - dy * (ix_se - ix);
|
||||
scalar_t ne_tmp = +dx * (iy_sw - iy) - dy * (ix - ix_sw);
|
||||
scalar_t sw_tmp = -dx * (iy - iy_ne) + dy * (ix_ne - ix);
|
||||
scalar_t se_tmp = +dx * (iy - iy_nw) + dy * (ix - ix_nw);
|
||||
|
||||
|
||||
// grad2_grad_grid_x * y * val + grad2_grad_grid_y * x * val
|
||||
*ggOut_ptr_NCHW += nw_val * nw_tmp + ne_tmp * ne_val + sw_tmp * sw_val + se_tmp * se_val;
|
||||
|
||||
// Computing gradient wrt input = grad2_grad_grid_x * grad_output * y + grad2_grad_grid_y * grad_output * x
|
||||
scalar_t gOut = *gOut_ptr_NCHW;
|
||||
//scalar_t val;
|
||||
//val = gOut * (-dx * (iy_se - iy) - dy * (ix_se - ix));
|
||||
safe_add_2d(grad_input.data, iy_nw, ix_nw, gInp_sH, gInp_sW, inp_H, inp_W, nw_tmp * gOut, NC_offset, grad_input_memory_span);
|
||||
//val = gOut * (+dx * (iy_sw - iy) - dy * (ix - ix_sw));
|
||||
safe_add_2d(grad_input.data, iy_ne, ix_ne, gInp_sH, gInp_sW, inp_H, inp_W, ne_tmp * gOut, NC_offset, grad_input_memory_span);
|
||||
//val = gOut * (-dx * (iy - iy_ne) + dy * (ix_ne - ix));
|
||||
safe_add_2d(grad_input.data, iy_sw, ix_sw, gInp_sH, gInp_sW, inp_H, inp_W, sw_tmp * gOut, NC_offset, grad_input_memory_span);
|
||||
//val = gOut * (+dx * (iy - iy_nw) + dy * (ix - ix_nw));
|
||||
safe_add_2d(grad_input.data, iy_se, ix_se, gInp_sH, gInp_sW, inp_H, inp_W, se_tmp * gOut, NC_offset, grad_input_memory_span);
|
||||
|
||||
scalar_t dxy = nw_val - ne_val - sw_val + se_val;
|
||||
// Computing gradient wrt grid_x = grad2_grad_input * y * gOut + grad2_grad_grid_y * val * gOut
|
||||
gix += gOut * (-g2_nw_val * (iy_se - iy) + g2_ne_val * (iy_sw - iy)
|
||||
-g2_sw_val * (iy - iy_ne) + g2_se_val * (iy - iy_nw));
|
||||
gix += gOut * dy * dxy;
|
||||
|
||||
// Computing gradient wrt grid_y = grad2_grad_input * x * gOut + grad2_grad_grid_x * val * gOut
|
||||
giy += gOut * (-g2_nw_val * (ix_se - ix) - g2_ne_val * (ix - ix_sw)
|
||||
+g2_sw_val * (ix_ne - ix) + g2_se_val * (ix - ix_nw));
|
||||
giy += gOut * dx * dxy;
|
||||
}
|
||||
|
||||
gGrid_ptr_NHW[0] = gix * gix_mult;
|
||||
gGrid_ptr_NHW[1] = giy * giy_mult;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename index_t>
|
||||
C10_LAUNCH_BOUNDS_1(256)
|
||||
__global__ void grid_sampler_3d_grad2_kernel(
|
||||
const index_t nthreads,
|
||||
TensorInfo<scalar_t, index_t> grad2_grad_input,
|
||||
TensorInfo<scalar_t, index_t> grad2_grad_grid,
|
||||
TensorInfo<scalar_t, index_t> grad_output,
|
||||
TensorInfo<scalar_t, index_t> input,
|
||||
TensorInfo<scalar_t, index_t> grid,
|
||||
TensorInfo<scalar_t, index_t> grad_grad_output,
|
||||
TensorInfo<scalar_t, index_t> grad_input,
|
||||
TensorInfo<scalar_t, index_t> grad_grid,
|
||||
const GridSamplerPadding padding_mode,
|
||||
bool align_corners,
|
||||
const index_t grad_input_memory_span) {
|
||||
|
||||
index_t C = input.sizes[1];
|
||||
index_t inp_D = input.sizes[2];
|
||||
index_t inp_H = input.sizes[3];
|
||||
index_t inp_W = input.sizes[4];
|
||||
|
||||
index_t out_D = grid.sizes[1];
|
||||
index_t out_H = grid.sizes[2];
|
||||
index_t out_W = grid.sizes[3];
|
||||
|
||||
index_t g2inp_sN = grad2_grad_input.strides[0];
|
||||
index_t g2inp_sC = grad2_grad_input.strides[1];
|
||||
index_t g2inp_sD = grad2_grad_input.strides[2];
|
||||
index_t g2inp_sH = grad2_grad_input.strides[3];
|
||||
index_t g2inp_sW = grad2_grad_input.strides[4];
|
||||
|
||||
index_t g2grid_sN = grad2_grad_grid.strides[0];
|
||||
index_t g2grid_sD = grad2_grad_grid.strides[1];
|
||||
index_t g2grid_sH = grad2_grad_grid.strides[2];
|
||||
index_t g2grid_sW = grad2_grad_grid.strides[3];
|
||||
index_t g2grid_sCoor = grad2_grad_grid.strides[4];
|
||||
|
||||
index_t gOut_sN = grad_output.strides[0];
|
||||
index_t gOut_sC = grad_output.strides[1];
|
||||
index_t gOut_sD = grad_output.strides[2];
|
||||
index_t gOut_sH = grad_output.strides[3];
|
||||
index_t gOut_sW = grad_output.strides[4];
|
||||
|
||||
index_t inp_sN = input.strides[0];
|
||||
index_t inp_sC = input.strides[1];
|
||||
index_t inp_sD = input.strides[2];
|
||||
index_t inp_sH = input.strides[3];
|
||||
index_t inp_sW = input.strides[4];
|
||||
|
||||
index_t grid_sN = grid.strides[0];
|
||||
index_t grid_sD = grid.strides[1];
|
||||
index_t grid_sH = grid.strides[2];
|
||||
index_t grid_sW = grid.strides[3];
|
||||
index_t grid_sCoor = grid.strides[4];
|
||||
|
||||
index_t gInp_sN = grad_input.strides[0];
|
||||
index_t gInp_sC = grad_input.strides[1];
|
||||
index_t gInp_sD = grad_input.strides[2];
|
||||
index_t gInp_sH = grad_input.strides[3];
|
||||
index_t gInp_sW = grad_input.strides[4];
|
||||
|
||||
index_t gGrid_sW = grad_grid.strides[3];
|
||||
|
||||
index_t ggOut_sN = grad_grad_output.strides[0];
|
||||
index_t ggOut_sC = grad_grad_output.strides[1];
|
||||
index_t ggOut_sD = grad_grad_output.strides[2];
|
||||
index_t ggOut_sH = grad_grad_output.strides[3];
|
||||
index_t ggOut_sW = grad_grad_output.strides[4];
|
||||
|
||||
CUDA_KERNEL_LOOP_TYPE(index, nthreads, index_t) {
|
||||
const index_t w = index % out_W;
|
||||
const index_t h = (index / out_W) % out_H;
|
||||
const index_t d = (index / (out_H * out_W)) % out_D;
|
||||
const index_t n = index / (out_D * out_H * out_W);
|
||||
|
||||
/* Grid related staff */
|
||||
index_t grid_offset = n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW;
|
||||
|
||||
// get the corresponding input x, y co-ordinates from grid
|
||||
scalar_t ix = grid.data[grid_offset];
|
||||
scalar_t iy = grid.data[grid_offset + grid_sCoor];
|
||||
scalar_t iz = grid.data[grid_offset + 2 * grid_sCoor];
|
||||
|
||||
// multipliers for gradients on ix and iy
|
||||
scalar_t gix_mult, giy_mult, giz_mult;
|
||||
ix = grid_sampler_compute_source_index_set_grad(ix, inp_W, padding_mode, align_corners, &gix_mult);
|
||||
iy = grid_sampler_compute_source_index_set_grad(iy, inp_H, padding_mode, align_corners, &giy_mult);
|
||||
iz = grid_sampler_compute_source_index_set_grad(iz, inp_D, padding_mode, align_corners, &giz_mult);
|
||||
|
||||
// get NE, NW, SE, SW pixel values from (x, y)
|
||||
index_t ix_tnw = static_cast<index_t>(::floor(ix));
|
||||
index_t iy_tnw = static_cast<index_t>(::floor(iy));
|
||||
index_t iz_tnw = static_cast<index_t>(::floor(iz));
|
||||
|
||||
index_t ix_tne = ix_tnw + 1;
|
||||
index_t iy_tne = iy_tnw;
|
||||
index_t iz_tne = iz_tnw;
|
||||
|
||||
index_t ix_tsw = ix_tnw;
|
||||
index_t iy_tsw = iy_tnw + 1;
|
||||
index_t iz_tsw = iz_tnw;
|
||||
|
||||
index_t ix_tse = ix_tnw + 1;
|
||||
index_t iy_tse = iy_tnw + 1;
|
||||
index_t iz_tse = iz_tnw;
|
||||
|
||||
index_t ix_bnw = ix_tnw;
|
||||
index_t iy_bnw = iy_tnw;
|
||||
index_t iz_bnw = iz_tnw + 1;
|
||||
|
||||
index_t ix_bne = ix_tnw + 1;
|
||||
index_t iy_bne = iy_tnw;
|
||||
index_t iz_bne = iz_tnw + 1;
|
||||
|
||||
index_t ix_bsw = ix_tnw;
|
||||
index_t iy_bsw = iy_tnw + 1;
|
||||
index_t iz_bsw = iz_tnw + 1;
|
||||
|
||||
index_t ix_bse = ix_tnw + 1;
|
||||
index_t iy_bse = iy_tnw + 1;
|
||||
index_t iz_bse = iz_tnw + 1;
|
||||
|
||||
// get surfaces to each neighbor:
|
||||
scalar_t tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
|
||||
scalar_t tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
|
||||
scalar_t tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
|
||||
scalar_t tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
|
||||
scalar_t bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
|
||||
scalar_t bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
|
||||
scalar_t bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
|
||||
scalar_t bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
|
||||
|
||||
/* grad2_grad_input related init */
|
||||
scalar_t *g2_inp_ptr_NC = grad2_grad_input.data + n * g2inp_sN;
|
||||
|
||||
/* grad2_grad_grid related init */
|
||||
grid_offset = n * g2grid_sN + d * g2grid_sD + h * g2grid_sH + w * g2grid_sW;
|
||||
scalar_t dx = grad2_grad_grid.data[grid_offset];
|
||||
scalar_t dy = grad2_grad_grid.data[grid_offset + g2grid_sCoor];
|
||||
scalar_t dz = grad2_grad_grid.data[grid_offset + 2 * g2grid_sCoor];
|
||||
|
||||
dx = dx * gix_mult;
|
||||
dy = dy * giy_mult;
|
||||
dz = dz * giz_mult;
|
||||
|
||||
/* grad_output related init */
|
||||
scalar_t *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
|
||||
|
||||
/* input related init */
|
||||
scalar_t *inp_ptr_NC = input.data + n * inp_sN;
|
||||
|
||||
/* grad_grad_output related init */
|
||||
scalar_t *ggOut_ptr_NCDHW = grad_grad_output.data + n * ggOut_sN + d * ggOut_sD + h * ggOut_sH + w * ggOut_sW;
|
||||
|
||||
/* grad_input related init */
|
||||
index_t NC_offset = n * gInp_sN;
|
||||
|
||||
/* grad_grid related init */
|
||||
scalar_t *gGrid_ptr_NDHW = grad_grid.data + index * gGrid_sW;
|
||||
scalar_t gix = static_cast<scalar_t>(0), giy = static_cast<scalar_t>(0), giz = static_cast<scalar_t>(0);
|
||||
|
||||
scalar_t tnw_val, tne_val, tsw_val, tse_val, bnw_val, bne_val, bsw_val, bse_val;
|
||||
scalar_t g2_tnw_val, g2_tne_val, g2_tsw_val, g2_tse_val, g2_bnw_val, g2_bne_val, g2_bsw_val, g2_bse_val;
|
||||
|
||||
scalar_t zero = static_cast<scalar_t>(0);
|
||||
for (index_t c = 0; c < C;
|
||||
++c,
|
||||
g2_inp_ptr_NC += g2inp_sC,
|
||||
inp_ptr_NC += inp_sC,
|
||||
NC_offset += gInp_sC,
|
||||
gOut_ptr_NCDHW += gOut_sC,
|
||||
ggOut_ptr_NCDHW += ggOut_sC) {
|
||||
|
||||
if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
|
||||
tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
|
||||
g2_tnw_val = g2_inp_ptr_NC[iz_tnw * g2inp_sD + iy_tnw * g2inp_sH + ix_tnw * g2inp_sW];
|
||||
} else {
|
||||
tnw_val = zero;
|
||||
g2_tnw_val = zero;
|
||||
}
|
||||
if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
|
||||
tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
|
||||
g2_tne_val = g2_inp_ptr_NC[iz_tne * g2inp_sD + iy_tne * g2inp_sH + ix_tne * g2inp_sW];
|
||||
} else {
|
||||
tne_val = zero;
|
||||
g2_tne_val = zero;
|
||||
}
|
||||
if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
|
||||
tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
|
||||
g2_tsw_val = g2_inp_ptr_NC[iz_tsw * g2inp_sD + iy_tsw * g2inp_sH + ix_tsw * g2inp_sW];
|
||||
} else {
|
||||
tsw_val = zero;
|
||||
g2_tsw_val = zero;
|
||||
}
|
||||
if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
|
||||
tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
|
||||
g2_tse_val = g2_inp_ptr_NC[iz_tse * g2inp_sD + iy_tse * g2inp_sH + ix_tse * g2inp_sW];
|
||||
} else {
|
||||
tse_val = zero;
|
||||
g2_tse_val = zero;
|
||||
}
|
||||
|
||||
if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
|
||||
bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
|
||||
g2_bnw_val = g2_inp_ptr_NC[iz_bnw * g2inp_sD + iy_bnw * g2inp_sH + ix_bnw * g2inp_sW];
|
||||
} else {
|
||||
bnw_val = zero;
|
||||
g2_bnw_val = zero;
|
||||
}
|
||||
if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
|
||||
bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
|
||||
g2_bne_val = g2_inp_ptr_NC[iz_bne * g2inp_sD + iy_bne * g2inp_sH + ix_bne * g2inp_sW];
|
||||
} else {
|
||||
bne_val = zero;
|
||||
g2_bne_val = zero;
|
||||
}
|
||||
if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
|
||||
bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
|
||||
g2_bsw_val = g2_inp_ptr_NC[iz_bsw * g2inp_sD + iy_bsw * g2inp_sH + ix_bsw * g2inp_sW];
|
||||
} else {
|
||||
bsw_val = zero;
|
||||
g2_bsw_val = zero;
|
||||
}
|
||||
if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
|
||||
bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
|
||||
g2_bse_val = g2_inp_ptr_NC[iz_bse * g2inp_sD + iy_bse * g2inp_sH + ix_bse * g2inp_sW];
|
||||
} else {
|
||||
bse_val = zero;
|
||||
g2_bse_val = zero;
|
||||
}
|
||||
|
||||
// Computing gradient wrt to grad_output =
|
||||
// grad2_grad_input * x * y * z
|
||||
*ggOut_ptr_NCDHW = static_cast<scalar_t>(0);
|
||||
*ggOut_ptr_NCDHW += g2_tnw_val * tnw + g2_tne_val * tne + g2_tsw_val * tsw + g2_tse_val * tse
|
||||
+g2_bnw_val * bnw + g2_bne_val * bne + g2_bsw_val * bsw + g2_bse_val * bse;
|
||||
|
||||
// +val * (grad2_grad_grid_x * y * z + grad2_grad_grid_y * x * z + grad2_grad_grid_z * x * y)
|
||||
scalar_t tnw_tmp = (-dx * (iy_bse - iy) * (iz_bse - iz) - dy * (ix_bse - ix) * (iz_bse - iz) - dz * (ix_bse - ix) * (iy_bse - iy));
|
||||
scalar_t tne_tmp = (+dx * (iy_bsw - iy) * (iz_bsw - iz) - dy * (ix - ix_bsw) * (iz_bsw - iz) - dz * (ix - ix_bsw) * (iy_bsw - iy));
|
||||
scalar_t tsw_tmp = (-dx * (iy - iy_bne) * (iz_bne - iz) + dy * (ix_bne - ix) * (iz_bne - iz) - dz * (ix_bne - ix) * (iy - iy_bne));
|
||||
scalar_t tse_tmp = (+dx * (iy - iy_bnw) * (iz_bnw - iz) + dy * (ix - ix_bnw) * (iz_bnw - iz) - dz * (ix - ix_bnw) * (iy - iy_bnw));
|
||||
scalar_t bnw_tmp = (-dx * (iy_tse - iy) * (iz - iz_tse) - dy * (ix_tse - ix) * (iz - iz_tse) + dz * (ix_tse - ix) * (iy_tse - iy));
|
||||
scalar_t bne_tmp = (+dx * (iy_tsw - iy) * (iz - iz_tsw) - dy * (ix - ix_tsw) * (iz - iz_tsw) + dz * (ix - ix_tsw) * (iy_tsw - iy));
|
||||
scalar_t bsw_tmp = (-dx * (iy - iy_tne) * (iz - iz_tne) + dy * (ix_tne - ix) * (iz - iz_tne) + dz * (ix_tne - ix) * (iy - iy_tne));
|
||||
scalar_t bse_tmp = (+dx * (iy - iy_tnw) * (iz - iz_tnw) + dy * (ix - ix_tnw) * (iz - iz_tnw) + dz * (ix - ix_tnw) * (iy - iy_tnw));
|
||||
|
||||
*ggOut_ptr_NCDHW += tnw_val * tnw_tmp + tne_val * tne_tmp + tsw_val * tsw_tmp + tse_val * tse_tmp
|
||||
+bnw_val * bnw_tmp + bne_val * bne_tmp + bsw_val * bsw_tmp + bse_val * bse_tmp;
|
||||
|
||||
// Computing gradient wrt input = grad2_grad_grid_x * grad_output * y * z + grad2_grad_grid_y * grad_output * x * z +
|
||||
// grad2_grad_grid_z * grad_output * y * z
|
||||
scalar_t gOut = *gOut_ptr_NCDHW;
|
||||
|
||||
safe_add_3d(grad_input.data, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
safe_add_3d(grad_input.data, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
safe_add_3d(grad_input.data, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
safe_add_3d(grad_input.data, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
safe_add_3d(grad_input.data, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
safe_add_3d(grad_input.data, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
safe_add_3d(grad_input.data, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
safe_add_3d(grad_input.data, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse_tmp * gOut,
|
||||
NC_offset, grad_input_memory_span);
|
||||
|
||||
//Computing gradient wrt grid
|
||||
scalar_t dxy = (tnw_val * (iz_bse - iz) - tne_val * (iz_bsw - iz)
|
||||
-tsw_val * (iz_bne - iz) + tse_val * (iz_bnw - iz)
|
||||
+bnw_val * (iz - iz_tse) - bne_val * (iz - iz_tsw)
|
||||
-bsw_val * (iz - iz_tne) + bse_val * (iz - iz_tnw));
|
||||
|
||||
scalar_t dxz = (tnw_val * (iy_bse - iy) - tne_val * (iy_bsw - iy)
|
||||
+tsw_val * (iy - iy_bne) - tse_val * (iy - iy_bnw)
|
||||
-bnw_val * (iy_tse - iy) + bne_val * (iy_tsw - iy)
|
||||
-bsw_val * (iy - iy_tne) + bse_val * (iy - iy_tnw));
|
||||
|
||||
scalar_t dyz = (tnw_val * (ix_bse - ix) + tne_val * (ix - ix_bsw)
|
||||
-tsw_val * (ix_bne - ix) - tse_val * (ix - ix_bnw)
|
||||
-bnw_val * (ix_tse - ix) - bne_val * (ix - ix_tsw)
|
||||
+bsw_val * (ix_tne - ix) + bse_val * (ix - ix_tnw));
|
||||
|
||||
|
||||
// Computing gradient wrt grid_x =
|
||||
// grad2_grad_input * z * y * gOut
|
||||
gix += gOut * (-g2_tnw_val * (iy_bse - iy) * (iz_bse - iz) + g2_tne_val * (iy_bsw - iy) * (iz_bsw - iz)
|
||||
-g2_tsw_val * (iy - iy_bne) * (iz_bne - iz) + g2_tse_val * (iy - iy_bnw) * (iz_bnw - iz)
|
||||
-g2_bnw_val * (iy_tse - iy) * (iz - iz_tse) + g2_bne_val * (iy_tsw - iy) * (iz - iz_tsw)
|
||||
-g2_bsw_val * (iy - iy_tne) * (iz - iz_tne) + g2_bse_val * (iy - iy_tnw) * (iz - iz_tnw));
|
||||
|
||||
//+ grad2_grad_grid_z * y * val * gOut + grad2_grad_grid_y * z * val * gOut
|
||||
gix += gOut * (dz * dxz + dy * dxy);
|
||||
|
||||
// Computing gradient wrt grid_y =
|
||||
// grad2_grad_input * x * z * gOut
|
||||
giy += gOut * (-g2_tnw_val * (ix_bse - ix) * (iz_bse - iz) - g2_tne_val * (ix - ix_bsw) * (iz_bsw - iz)
|
||||
+g2_tsw_val * (ix_bne - ix) * (iz_bne - iz) + g2_tse_val * (ix - ix_bnw) * (iz_bnw - iz)
|
||||
-g2_bnw_val * (ix_tse - ix) * (iz - iz_tse) - g2_bne_val * (ix - ix_tsw) * (iz - iz_tsw)
|
||||
+g2_bsw_val * (ix_tne - ix) * (iz - iz_tne) + g2_bse_val * (ix - ix_tnw) * (iz - iz_tnw));
|
||||
//+ grad2_grad_grid_x * z * val * gOut + grad2_grad_grid_z * x * val * gOut
|
||||
giy += gOut * (dx * dxy + dz * dyz);
|
||||
|
||||
// Computing gradient wrt grid_z =
|
||||
// grad2_grad_input * x * y * gOut
|
||||
giz += gOut * (-g2_tnw_val * (ix_bse - ix) * (iy_bse - iy) - g2_tne_val * (ix - ix_bsw) * (iy_bsw - iy)
|
||||
-g2_tsw_val * (ix_bne - ix) * (iy - iy_bne) - g2_tse_val * (ix - ix_bnw) * (iy - iy_bnw)
|
||||
+g2_bnw_val * (ix_tse - ix) * (iy_tse - iy) + g2_bne_val * (ix - ix_tsw) * (iy_tsw - iy)
|
||||
+g2_bsw_val * (ix_tne - ix) * (iy - iy_tne) + g2_bse_val * (ix - ix_tnw) * (iy - iy_tnw));
|
||||
//+ grad2_grad_grid_x * y * val * gOut + grad2_grad_grid_y * x * val * gOut
|
||||
giz += gOut * (dx * dxz + dy * dyz);
|
||||
}
|
||||
|
||||
gGrid_ptr_NDHW[0] = gix * gix_mult;
|
||||
gGrid_ptr_NDHW[1] = giy * giy_mult;
|
||||
gGrid_ptr_NDHW[2] = giz * giz_mult;
|
||||
}
|
||||
}}
|
||||
|
||||
|
||||
std::vector<torch::Tensor> grid_sample2d_cuda_grad2(
|
||||
const torch::Tensor &grad2_grad_input,
|
||||
const torch::Tensor &grad2_grad_grid,
|
||||
const torch::Tensor &grad_output,
|
||||
const torch::Tensor &input,
|
||||
const torch::Tensor &grid,
|
||||
bool padding_mode,
|
||||
bool align_corners) {
|
||||
|
||||
const auto batch_size = input.size(0);
|
||||
const auto C = input.size(1);
|
||||
const auto H_IN = input.size(2);
|
||||
const auto W_IN = input.size(3);
|
||||
|
||||
const auto H_OUT = grid.size(1);
|
||||
const auto W_OUT = grid.size(2);
|
||||
|
||||
torch::Tensor grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
torch::Tensor grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
torch::Tensor grad_grid = torch::zeros_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
|
||||
int64_t count = batch_size * H_OUT * W_OUT;
|
||||
|
||||
if (count > 0) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_2d_grad2_cuda", [&] {
|
||||
if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
|
||||
canUse32BitIndexMath(grad_output)) {
|
||||
grid_sampler_2d_grad2_kernel<scalar_t>
|
||||
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
static_cast<int>(count),
|
||||
getTensorInfo<scalar_t, int>(grad2_grad_input),
|
||||
getTensorInfo<scalar_t, int>(grad2_grad_grid),
|
||||
getTensorInfo<scalar_t, int>(grad_output),
|
||||
getTensorInfo<scalar_t, int>(input),
|
||||
getTensorInfo<scalar_t, int>(grid),
|
||||
getTensorInfo<scalar_t, int>(grad_grad_output),
|
||||
getTensorInfo<scalar_t, int>(grad_input),
|
||||
getTensorInfo<scalar_t, int>(grad_grid),
|
||||
static_cast<GridSamplerPadding>(padding_mode),
|
||||
align_corners,
|
||||
static_cast<int>(grad_input.numel()));
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
} else {
|
||||
grid_sampler_2d_grad2_kernel<scalar_t>
|
||||
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
count,
|
||||
getTensorInfo<scalar_t, int64_t>(grad2_grad_input),
|
||||
getTensorInfo<scalar_t, int64_t>(grad2_grad_grid),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_output),
|
||||
getTensorInfo<scalar_t, int64_t>(input),
|
||||
getTensorInfo<scalar_t, int64_t>(grid),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_grad_output),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_input),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_grid),
|
||||
static_cast<GridSamplerPadding>(padding_mode),
|
||||
align_corners,
|
||||
grad_input.numel());
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return {grad_grad_output, grad_input, grad_grid};
|
||||
}
|
||||
|
||||
std::vector<torch::Tensor> grid_sample3d_cuda_grad2(
|
||||
const torch::Tensor &grad2_grad_input,
|
||||
const torch::Tensor &grad2_grad_grid,
|
||||
const torch::Tensor &grad_output,
|
||||
const torch::Tensor &input,
|
||||
const torch::Tensor &grid,
|
||||
bool padding_mode,
|
||||
bool align_corners) {
|
||||
|
||||
const auto batch_size = input.size(0);
|
||||
const auto C = input.size(1);
|
||||
const auto D_IN = input.size(2);
|
||||
const auto H_IN = input.size(3);
|
||||
const auto W_IN = input.size(4);
|
||||
|
||||
const auto D_OUT = grid.size(1);
|
||||
const auto H_OUT = grid.size(2);
|
||||
const auto W_OUT = grid.size(3);
|
||||
|
||||
torch::Tensor grad_grad_output = torch::zeros_like(grad_output, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
torch::Tensor grad_input = torch::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
torch::Tensor grad_grid = torch::zeros_like(grid, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
|
||||
int64_t count = batch_size * D_OUT * H_OUT * W_OUT;
|
||||
|
||||
if (count > 0) {
|
||||
AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "grid_sampler_3d_grad2_cuda", [&] {
|
||||
if (canUse32BitIndexMath(input) && canUse32BitIndexMath(grid) &&
|
||||
canUse32BitIndexMath(grad_output)) {
|
||||
grid_sampler_3d_grad2_kernel<scalar_t>
|
||||
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
static_cast<int>(count),
|
||||
getTensorInfo<scalar_t, int>(grad2_grad_input),
|
||||
getTensorInfo<scalar_t, int>(grad2_grad_grid),
|
||||
getTensorInfo<scalar_t, int>(grad_output),
|
||||
getTensorInfo<scalar_t, int>(input),
|
||||
getTensorInfo<scalar_t, int>(grid),
|
||||
getTensorInfo<scalar_t, int>(grad_grad_output),
|
||||
getTensorInfo<scalar_t, int>(grad_input),
|
||||
getTensorInfo<scalar_t, int>(grad_grid),
|
||||
static_cast<GridSamplerPadding>(padding_mode),
|
||||
align_corners,
|
||||
static_cast<int>(grad_input.numel()));
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
} else {
|
||||
grid_sampler_3d_grad2_kernel<scalar_t>
|
||||
<<<GET_BLOCKS(count, 256), 256, 0, at::cuda::getCurrentCUDAStream()>>>(
|
||||
count,
|
||||
getTensorInfo<scalar_t, int64_t>(grad2_grad_input),
|
||||
getTensorInfo<scalar_t, int64_t>(grad2_grad_grid),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_output),
|
||||
getTensorInfo<scalar_t, int64_t>(input),
|
||||
getTensorInfo<scalar_t, int64_t>(grid),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_grad_output),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_input),
|
||||
getTensorInfo<scalar_t, int64_t>(grad_grid),
|
||||
static_cast<GridSamplerPadding>(padding_mode),
|
||||
align_corners,
|
||||
grad_input.numel());
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
return {grad_grad_output, grad_input, grad_grid};
|
||||
}
|
||||
|
||||
}}
|
93
svrm/ldm/modules/rendering_neus/third_party/pytorch_ssim/__init__.py
vendored
Normal file
@ -0,0 +1,93 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Variable
|
||||
import numpy as np
|
||||
from math import exp
|
||||
|
||||
|
||||
def gaussian(window_size, sigma):
|
||||
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
|
||||
return gauss / gauss.sum()
|
||||
|
||||
|
||||
def create_window(window_size, channel):
|
||||
_1D_window = gaussian(window_size, 1.5).unsqueeze(1)
|
||||
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
|
||||
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
|
||||
return window
|
||||
|
||||
|
||||
def _ssim(img1, img2, window, window_size, channel, use_padding, size_average=True):
|
||||
|
||||
if use_padding:
|
||||
padding_size = window_size // 2
|
||||
else:
|
||||
padding_size = 0
|
||||
|
||||
mu1 = F.conv2d(img1, window, padding=padding_size, groups=channel)
|
||||
mu2 = F.conv2d(img2, window, padding=padding_size, groups=channel)
|
||||
|
||||
mu1_sq = mu1.pow(2)
|
||||
mu2_sq = mu2.pow(2)
|
||||
mu1_mu2 = mu1 * mu2
|
||||
|
||||
sigma1_sq = F.conv2d(img1 * img1, window, padding=padding_size, groups=channel) - mu1_sq
|
||||
sigma2_sq = F.conv2d(img2 * img2, window, padding=padding_size, groups=channel) - mu2_sq
|
||||
sigma12 = F.conv2d(img1 * img2, window, padding=padding_size, groups=channel) - mu1_mu2
|
||||
|
||||
C1 = 0.01 ** 2
|
||||
C2 = 0.03 ** 2
|
||||
|
||||
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
||||
|
||||
if size_average:
|
||||
return ssim_map.mean()
|
||||
else:
|
||||
return ssim_map.mean(1).mean(1).mean(1)
|
||||
|
||||
|
||||
class SSIM(torch.nn.Module):
|
||||
def __init__(self, window_size=11, use_padding=True, size_average=True):
|
||||
super(SSIM, self).__init__()
|
||||
self.window_size = window_size
|
||||
self.size_average = size_average
|
||||
self.use_padding = use_padding
|
||||
self.channel = 1
|
||||
self.window = create_window(window_size, self.channel)
|
||||
|
||||
def forward(self, img1, img2):
|
||||
(_, channel, _, _) = img1.size()
|
||||
|
||||
if channel == self.channel and self.window.data.type() == img1.data.type():
|
||||
window = self.window
|
||||
else:
|
||||
window = create_window(self.window_size, channel)
|
||||
|
||||
if img1.is_cuda:
|
||||
window = window.cuda(img1.get_device())
|
||||
window = window.type_as(img1)
|
||||
|
||||
self.window = window
|
||||
self.channel = channel
|
||||
|
||||
return _ssim(img1, img2, window, self.window_size, channel, self.use_padding, self.size_average)
|
||||
|
||||
|
||||
def ssim(img1, img2, use_padding=True, window_size=11, size_average=True):
|
||||
"""SSIM only defined at intensity channel. For RGB or YUV or other image format, this function computes SSIm at each
|
||||
channel and averge them.
|
||||
:param img1: (B, C, H, W) float32 in [0, 1]
|
||||
:param img2: (B, C, H, W) float32 in [0, 1]
|
||||
:param use_padding: we use conv2d when we compute mean and var for each patch, this use_padding is for that conv2d.
|
||||
:param window_size: patch size
|
||||
:param size_average:
|
||||
:return: a tensor that contains only one scalar.
|
||||
"""
|
||||
(_, channel, _, _) = img1.size()
|
||||
window = create_window(window_size, channel)
|
||||
|
||||
if img1.is_cuda:
|
||||
window = window.cuda(img1.get_device())
|
||||
window = window.type_as(img1)
|
||||
|
||||
return _ssim(img1, img2, window, window_size, channel, use_padding, size_average)
|
9
svrm/ldm/modules/rendering_neus/utils/__init__.py
Normal file
@ -0,0 +1,9 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
118
svrm/ldm/modules/rendering_neus/utils/math_utils.py
Normal file
@ -0,0 +1,118 @@
|
||||
# MIT License
|
||||
|
||||
# Copyright (c) 2022 Petr Kellnhofer
|
||||
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import torch
|
||||
|
||||
def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Left-multiplies MxM @ NxM. Returns NxM.
|
||||
"""
|
||||
res = torch.matmul(vectors4, matrix.T)
|
||||
return res
|
||||
|
||||
|
||||
def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Normalize vector lengths.
|
||||
"""
|
||||
return vectors / (torch.norm(vectors, dim=-1, keepdim=True))
|
||||
|
||||
def torch_dot(x: torch.Tensor, y: torch.Tensor):
|
||||
"""
|
||||
Dot product of two tensors.
|
||||
"""
|
||||
return (x * y).sum(-1)
|
||||
|
||||
|
||||
def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length):
|
||||
"""
|
||||
Author: Petr Kellnhofer
|
||||
Intersects rays with the [-1, 1] NDC volume.
|
||||
Returns min and max distance of entry.
|
||||
Returns -1 for no intersection.
|
||||
https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection
|
||||
"""
|
||||
o_shape = rays_o.shape
|
||||
rays_o = rays_o.detach().reshape(-1, 3)
|
||||
rays_d = rays_d.detach().reshape(-1, 3)
|
||||
|
||||
|
||||
bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)]
|
||||
bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)]
|
||||
bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device)
|
||||
is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device)
|
||||
|
||||
# Precompute inverse for stability.
|
||||
invdir = 1 / rays_d
|
||||
sign = (invdir < 0).long()
|
||||
|
||||
# Intersect with YZ plane.
|
||||
tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
|
||||
tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0]
|
||||
|
||||
# Intersect with XZ plane.
|
||||
tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
|
||||
tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1]
|
||||
|
||||
# Resolve parallel rays.
|
||||
is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False
|
||||
|
||||
# Use the shortest intersection.
|
||||
tmin = torch.max(tmin, tymin)
|
||||
tmax = torch.min(tmax, tymax)
|
||||
|
||||
# Intersect with XY plane.
|
||||
tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
|
||||
tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2]
|
||||
|
||||
# Resolve parallel rays.
|
||||
is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False
|
||||
|
||||
# Use the shortest intersection.
|
||||
tmin = torch.max(tmin, tzmin)
|
||||
tmax = torch.min(tmax, tzmax)
|
||||
|
||||
# Mark invalid.
|
||||
tmin[torch.logical_not(is_valid)] = -1
|
||||
tmax[torch.logical_not(is_valid)] = -2
|
||||
|
||||
return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1)
|
||||
|
||||
|
||||
def linspace(start: torch.Tensor, stop: torch.Tensor, num: int):
|
||||
"""
|
||||
Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive.
|
||||
Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch.
|
||||
"""
|
||||
# create a tensor of 'num' steps from 0 to 1
|
||||
steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1)
|
||||
|
||||
# reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings
|
||||
# - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript
|
||||
# "cannot statically infer the expected size of a list in this contex", hence the code below
|
||||
for i in range(start.ndim):
|
||||
steps = steps.unsqueeze(-1)
|
||||
|
||||
# the output starts at 'start' and increments until 'stop' in each dimension
|
||||
out = start[None] + steps * (stop - start)[None]
|
||||
|
||||
return out
|
140
svrm/ldm/modules/rendering_neus/utils/ray_marcher.py
Normal file
@ -0,0 +1,140 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
#
|
||||
# Modified by Zexin He
|
||||
# The modifications are subject to the same license as the original.
|
||||
|
||||
|
||||
"""
|
||||
The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths.
|
||||
Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!)
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class LearnedVariance(nn.Module):
|
||||
def __init__(self, init_val):
|
||||
super(LearnedVariance, self).__init__()
|
||||
self.register_parameter("_inv_std", nn.Parameter(torch.tensor(init_val)))
|
||||
|
||||
@property
|
||||
def inv_std(self):
|
||||
val = torch.exp(self._inv_std * 10.0)
|
||||
return val
|
||||
|
||||
def forward(self, x):
|
||||
return torch.ones_like(x) * self.inv_std.clamp(1.0e-6, 1.0e6)
|
||||
|
||||
|
||||
class MipRayMarcher2(nn.Module):
|
||||
def __init__(self, activation_factory):
|
||||
super().__init__()
|
||||
self.activation_factory = activation_factory
|
||||
self.variance = LearnedVariance(0.3)
|
||||
self.cos_anneal_ratio = 1.0
|
||||
def get_alpha(self, sdf, normal, dirs, dists):
|
||||
# sdf: [N 1] normal: [N 3] dirs: [N 3] dists: [N 1]
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
inv_std = self.variance(sdf)
|
||||
|
||||
true_cos = (dirs * normal).sum(-1, keepdim=True)
|
||||
# "cos_anneal_ratio" grows from 0 to 1 in the beginning training iterations. The anneal strategy below makes
|
||||
# the cos value "not dead" at the beginning training iterations, for better convergence.
|
||||
iter_cos = -(
|
||||
F.relu(-true_cos * 0.5 + 0.5) * (1.0 - self.cos_anneal_ratio)
|
||||
+ F.relu(-true_cos) * self.cos_anneal_ratio
|
||||
) # always non-positive
|
||||
|
||||
# Estimate signed distances at section points
|
||||
estimated_next_sdf = sdf + iter_cos * dists * 0.5
|
||||
estimated_prev_sdf = sdf - iter_cos * dists * 0.5
|
||||
|
||||
prev_cdf = torch.sigmoid(estimated_prev_sdf * inv_std)
|
||||
next_cdf = torch.sigmoid(estimated_next_sdf * inv_std)
|
||||
|
||||
p = prev_cdf - next_cdf
|
||||
c = prev_cdf
|
||||
|
||||
alpha = ((p + 1e-5) / (c + 1e-5)).clip(0.0, 1.0)
|
||||
return alpha
|
||||
|
||||
def run_forward(self, colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor=None, real_normals=None):
|
||||
# depths: [B N_ray*N_sample 1]
|
||||
# sdfs: [B, N_ray, N_sample 1]
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
deltas = depths[:, :, 1:] - depths[:, :, :-1]
|
||||
colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2
|
||||
sdfs_mid = (sdfs[:, :, :-1] + sdfs[:, :, 1:]) / 2
|
||||
depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2
|
||||
normals_mid = (normals[:, :, :-1] + normals[:, :, 1:]) / 2
|
||||
|
||||
# zhaohx add for normal :
|
||||
real_normals_mid = (real_normals[:, :, :-1] + real_normals[:, :, 1:]) / 2
|
||||
|
||||
# # using factory mode for better usability
|
||||
# densities_mid = self.activation_factory(rendering_options)(densities_mid)
|
||||
|
||||
# density_delta = densities_mid * deltas
|
||||
|
||||
# alpha = 1 - torch.exp(-density_delta)
|
||||
|
||||
# import ipdb; ipdb.set_trace()
|
||||
dirs = ray_directions.unsqueeze(2).expand(-1, -1, sdfs_mid.shape[-2], -1)
|
||||
B, N_ray, N_sample, _ = sdfs_mid.shape
|
||||
alpha = self.get_alpha(sdfs_mid.reshape(-1, 1), normals_mid.reshape(-1, 3), dirs.reshape(-1, 3), deltas.reshape(-1, 1))
|
||||
alpha = alpha.reshape(B, N_ray, N_sample, -1)
|
||||
|
||||
alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2)
|
||||
weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1]
|
||||
|
||||
composite_rgb = torch.sum(weights * colors_mid, -2)
|
||||
weight_total = weights.sum(2)
|
||||
composite_depth = torch.sum(weights * depths_mid, -2) / weight_total
|
||||
|
||||
# clip the composite to min/max range of depths
|
||||
composite_depth = torch.nan_to_num(composite_depth, float('inf'))
|
||||
composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths))
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
# zhaohx add for normal :
|
||||
composite_normal = torch.sum(weights * real_normals_mid, -2) / weight_total
|
||||
composite_normal = torch.nan_to_num(composite_normal, float('inf'))
|
||||
composite_normal = torch.clamp(composite_normal, torch.min(real_normals), torch.max(real_normals))
|
||||
|
||||
if rendering_options.get('white_back', False):
|
||||
# composite_rgb = composite_rgb + 1 - weight_total
|
||||
# weight_total[weight_total < 0.5] = 0
|
||||
# composite_rgb = composite_rgb * weight_total + 1 - weight_total
|
||||
# now is this
|
||||
if bgcolor is None:
|
||||
composite_rgb = composite_rgb + 1 - weight_total
|
||||
# composite_rgb = composite_rgb * weight_total + 1 - weight_total
|
||||
else:
|
||||
# import pdb; pdb.set_trace()
|
||||
bgcolor = bgcolor.permute(0, 2, 3, 1).contiguous().view(composite_rgb.shape[0], -1, composite_rgb.shape[-1])
|
||||
composite_rgb = composite_rgb + (1 - weight_total) * bgcolor
|
||||
# composite_rgb = composite_rgb * weight_total + (1 - weight_total) * bgcolor
|
||||
# composite_rgb = composite_rgb
|
||||
# print('new white_back')
|
||||
|
||||
# rendered value scale is 0-1, comment out original mipnerf scaling
|
||||
# composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1)
|
||||
|
||||
return composite_rgb, composite_depth, weights, composite_normal
|
||||
|
||||
|
||||
def forward(self, colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor=None, real_normals=None):
|
||||
composite_rgb, composite_depth, weights, composite_normal = self.run_forward(colors, sdfs, depths, normals, ray_directions, rendering_options, bgcolor, real_normals)
|
||||
|
||||
return composite_rgb, composite_depth, weights, composite_normal
|
81
svrm/ldm/modules/rendering_neus/utils/ray_sampler.py
Normal file
@ -0,0 +1,81 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
#
|
||||
# Modified by Zexin He
|
||||
# The modifications are subject to the same license as the original.
|
||||
|
||||
|
||||
"""
|
||||
The ray sampler is a module that takes in camera matrices and resolution and batches of rays.
|
||||
Expects cam2world matrices that use the OpenCV camera coordinate system conventions.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
class RaySampler(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None
|
||||
|
||||
|
||||
def forward(self, cam2world_matrix, intrinsics, render_size):
|
||||
"""
|
||||
Create batches of rays and return origins and directions.
|
||||
|
||||
cam2world_matrix: (N, 4, 4)
|
||||
intrinsics: (N, 3, 3)
|
||||
render_size: int
|
||||
|
||||
ray_origins: (N, M, 3)
|
||||
ray_dirs: (N, M, 2)
|
||||
"""
|
||||
|
||||
N, M = cam2world_matrix.shape[0], render_size**2
|
||||
cam_locs_world = cam2world_matrix[:, :3, 3]
|
||||
fx = intrinsics[:, 0, 0]
|
||||
fy = intrinsics[:, 1, 1]
|
||||
cx = intrinsics[:, 0, 2]
|
||||
cy = intrinsics[:, 1, 2]
|
||||
sk = intrinsics[:, 0, 1]
|
||||
|
||||
uv = torch.stack(torch.meshgrid(
|
||||
torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
|
||||
torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device),
|
||||
indexing='ij',
|
||||
))
|
||||
uv = uv.flip(0).reshape(2, -1).transpose(1, 0)
|
||||
uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1)
|
||||
|
||||
x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size)
|
||||
y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size)
|
||||
z_cam = torch.ones((N, M), device=cam2world_matrix.device)
|
||||
|
||||
x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam
|
||||
y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam
|
||||
|
||||
cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1)
|
||||
|
||||
_opencv2blender = torch.tensor([
|
||||
[1, 0, 0, 0],
|
||||
[0, -1, 0, 0],
|
||||
[0, 0, -1, 0],
|
||||
[0, 0, 0, 1],
|
||||
], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1)
|
||||
|
||||
cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender)
|
||||
|
||||
world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3]
|
||||
|
||||
ray_dirs = world_rel_points - cam_locs_world[:, None, :]
|
||||
ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2)
|
||||
|
||||
ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1)
|
||||
|
||||
return ray_origins, ray_dirs
|
331
svrm/ldm/modules/rendering_neus/utils/renderer.py
Normal file
@ -0,0 +1,331 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
|
||||
#
|
||||
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
# property and proprietary rights in and to this material, related
|
||||
# documentation and any modifications thereto. Any use, reproduction,
|
||||
# disclosure or distribution of this material and related documentation
|
||||
# without an express license agreement from NVIDIA CORPORATION or
|
||||
# its affiliates is strictly prohibited.
|
||||
#
|
||||
# Modified by Zexin He
|
||||
# The modifications are subject to the same license as the original.
|
||||
|
||||
|
||||
"""
|
||||
The renderer is a module that takes in rays, decides where to sample along each
|
||||
ray, and computes pixel colors using the volume rendering equation.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .ray_marcher import MipRayMarcher2
|
||||
from . import math_utils
|
||||
# from ldm.modules.rendering_neus.third_party.ops import grid_sample
|
||||
|
||||
def generate_planes():
|
||||
"""
|
||||
Defines planes by the three vectors that form the "axes" of the
|
||||
plane. Should work with arbitrary number of planes and planes of
|
||||
arbitrary orientation.
|
||||
|
||||
Bugfix reference: https://github.com/NVlabs/eg3d/issues/67
|
||||
"""
|
||||
return torch.tensor([[[1, 0, 0],
|
||||
[0, 1, 0],
|
||||
[0, 0, 1]],
|
||||
[[1, 0, 0],
|
||||
[0, 0, 1],
|
||||
[0, 1, 0]],
|
||||
[[0, 0, 1],
|
||||
[0, 1, 0],
|
||||
[1, 0, 0]]], dtype=torch.float32)
|
||||
|
||||
def project_onto_planes(planes, coordinates):
|
||||
"""
|
||||
Does a projection of a 3D point onto a batch of 2D planes,
|
||||
returning 2D plane coordinates.
|
||||
|
||||
Takes plane axes of shape n_planes, 3, 3
|
||||
# Takes coordinates of shape N, M, 3
|
||||
# returns projections of shape N*n_planes, M, 2
|
||||
"""
|
||||
N, M, C = coordinates.shape
|
||||
n_planes, _, _ = planes.shape
|
||||
coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
|
||||
inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
|
||||
projections = torch.bmm(coordinates, inv_planes)
|
||||
return projections[..., :2]
|
||||
|
||||
def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
|
||||
assert padding_mode == 'zeros'
|
||||
N, n_planes, C, H, W = plane_features.shape
|
||||
_, M, _ = coordinates.shape
|
||||
plane_features = plane_features.view(N*n_planes, C, H, W)
|
||||
|
||||
coordinates = (2/box_warp) * coordinates # add specific box bounds
|
||||
# print(coordinates.max(), coordinates.min())
|
||||
# import pdb; pdb.set_trace()
|
||||
|
||||
projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
|
||||
# output_features = grid_sample.grid_sample_2d(plane_features, projected_coordinates.float().to(plane_features.device)).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
|
||||
|
||||
output_features = torch.nn.functional.grid_sample(plane_features.float(), projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
|
||||
return output_features
|
||||
|
||||
def sample_from_3dgrid(grid, coordinates):
|
||||
"""
|
||||
Expects coordinates in shape (batch_size, num_points_per_batch, 3)
|
||||
Expects grid in shape (1, channels, H, W, D)
|
||||
(Also works if grid has batch size)
|
||||
Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels)
|
||||
"""
|
||||
batch_size, n_coords, n_dims = coordinates.shape
|
||||
sampled_features = torch.nn.functional.grid_sample(grid.expand(batch_size, -1, -1, -1, -1),
|
||||
coordinates.reshape(batch_size, 1, 1, -1, n_dims),
|
||||
mode='bilinear', padding_mode='zeros', align_corners=False)
|
||||
N, C, H, W, D = sampled_features.shape
|
||||
sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C)
|
||||
return sampled_features
|
||||
|
||||
class ImportanceRenderer(torch.nn.Module):
|
||||
"""
|
||||
Modified original version to filter out-of-box samples as TensoRF does.
|
||||
|
||||
Reference:
|
||||
TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.activation_factory = self._build_activation_factory()
|
||||
self.ray_marcher = MipRayMarcher2(self.activation_factory)
|
||||
self.plane_axes = generate_planes()
|
||||
|
||||
def _build_activation_factory(self):
|
||||
def activation_factory(options: dict):
|
||||
if options['clamp_mode'] == 'softplus':
|
||||
return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better
|
||||
else:
|
||||
assert False, "Renderer only supports `clamp_mode`=`softplus`!"
|
||||
return activation_factory
|
||||
|
||||
def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor,
|
||||
planes: torch.Tensor, decoder: nn.Module, rendering_options: dict):
|
||||
"""
|
||||
Additional filtering is applied to filter out-of-box samples.
|
||||
Modifications made by Zexin He.
|
||||
"""
|
||||
|
||||
# context related variables
|
||||
batch_size, num_rays, samples_per_ray, _ = depths.shape
|
||||
device = depths.device
|
||||
|
||||
# define sample points with depths
|
||||
sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3)
|
||||
sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3)
|
||||
# print(f'min bbox: {sample_coordinates.min()}, max bbox: {sample_coordinates.max()}')
|
||||
# import pdb; pdb.set_trace()
|
||||
# filter out-of-box samples
|
||||
mask_inbox = \
|
||||
(rendering_options['sampler_bbox_min'] <= sample_coordinates) & \
|
||||
(sample_coordinates <= rendering_options['sampler_bbox_max'])
|
||||
mask_inbox = mask_inbox.all(-1)
|
||||
|
||||
# forward model according to all samples
|
||||
_out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options)
|
||||
|
||||
# set out-of-box samples to zeros(rgb) & -inf(sigma)
|
||||
SAFE_GUARD = 3
|
||||
DATA_TYPE = _out['sdf'].dtype
|
||||
colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE)
|
||||
normals_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE)
|
||||
sdfs_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD
|
||||
|
||||
# print(DATA_TYPE)
|
||||
# import pdb; pdb.set_trace()
|
||||
# colors_pass[mask_inbox], sdfs_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sdf'][mask_inbox]
|
||||
colors_pass[mask_inbox], sdfs_pass = _out['rgb'][mask_inbox], _out['sdf']
|
||||
normals_pass = _out['normal']
|
||||
|
||||
# reshape back
|
||||
colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1])
|
||||
sdfs_pass = sdfs_pass.reshape(batch_size, num_rays, samples_per_ray, sdfs_pass.shape[-1])
|
||||
normals_pass = normals_pass.reshape(batch_size, num_rays, samples_per_ray, normals_pass.shape[-1])
|
||||
|
||||
return colors_pass, sdfs_pass, normals_pass, _out['sdf_grad']
|
||||
|
||||
def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options, bgcolor=None):
|
||||
# self.plane_axes = self.plane_axes.to(ray_origins.device)
|
||||
|
||||
if rendering_options['ray_start'] == 'auto' == rendering_options['ray_end']:
|
||||
ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp']) # [1, N_ray, 1]
|
||||
is_ray_valid = ray_end > ray_start
|
||||
if torch.any(is_ray_valid).item():
|
||||
ray_start[~is_ray_valid] = ray_start[is_ray_valid].min()
|
||||
ray_end[~is_ray_valid] = ray_start[is_ray_valid].max()
|
||||
depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) # [1, N_ray, N_sample, 1]】
|
||||
else:
|
||||
# Create stratified depth samples
|
||||
depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling'])
|
||||
|
||||
|
||||
# Coarse Pass
|
||||
colors_coarse, sdfs_coarse, normals_coarse, sdf_grad = self._forward_pass(
|
||||
depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins,
|
||||
planes=planes, decoder=decoder, rendering_options=rendering_options)
|
||||
|
||||
|
||||
# Fine Pass
|
||||
N_importance = rendering_options['depth_resolution_importance']
|
||||
# TODO
|
||||
if N_importance > 0:
|
||||
_, _, weights = self.ray_marcher(colors_coarse, sdfs_coarse, depths_coarse, sdf_grad.reshape(*normals_coarse.shape), ray_directions, rendering_options, bgcolor)
|
||||
|
||||
depths_fine = self.sample_importance(depths_coarse, weights, N_importance)
|
||||
|
||||
colors_fine, densities_fine = self._forward_pass(
|
||||
depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins,
|
||||
planes=planes, decoder=decoder, rendering_options=rendering_options)
|
||||
|
||||
all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse,
|
||||
depths_fine, colors_fine, densities_fine)
|
||||
####
|
||||
# dists = depths_coarse[:, :, 1:, :] - depths_coarse[:, :, :-1, :]
|
||||
# inter = (ray_end - ray_start) / ( rendering_options['depth_resolution'] + rendering_options['depth_resolution_importance'] - 1) # [1, N_ray, 1]
|
||||
# dists = torch.cat([dists, inter.unsqueeze(2), 2])
|
||||
####
|
||||
|
||||
# Aggregate
|
||||
rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options, bgcolor)
|
||||
else:
|
||||
# # import pdb; pdb.set_trace()
|
||||
# dists = depths_coarse[:, :, 1:, :] - depths_coarse[:, :, :-1, :]
|
||||
# inter = (ray_end - ray_start) / ( rendering_options['depth_resolution'] - 1) # [1, N_ray, 1]
|
||||
# dists = torch.cat([dists, inter.unsqueeze(2)], 2)
|
||||
# # import ipdb; ipdb.set_trace()
|
||||
|
||||
# rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, sdfs_coarse, depths_coarse, normals_coarse, dists, ray_directions, rendering_options, bgcolor)
|
||||
rgb_final, depth_final, weights, normal_final = self.ray_marcher(colors_coarse, sdfs_coarse, depths_coarse, sdf_grad.reshape(*normals_coarse.shape), ray_directions, rendering_options, bgcolor, normals_coarse)
|
||||
# import ipdb; ipdb.set_trace()
|
||||
|
||||
return rgb_final, depth_final, weights.sum(2), sdf_grad, normal_final
|
||||
|
||||
def run_model(self, planes, decoder, sample_coordinates, sample_directions, options):
|
||||
plane_axes = self.plane_axes.to(planes.device)
|
||||
|
||||
out = decoder(sample_directions, sample_coordinates, plane_axes, planes, options)
|
||||
# if options.get('density_noise', 0) > 0:
|
||||
# out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise']
|
||||
return out
|
||||
|
||||
def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options):
|
||||
out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options)
|
||||
out['sigma'] = self.activation_factory(options)(out['sigma'])
|
||||
return out
|
||||
|
||||
def sort_samples(self, all_depths, all_colors, all_densities):
|
||||
_, indices = torch.sort(all_depths, dim=-2)
|
||||
all_depths = torch.gather(all_depths, -2, indices)
|
||||
all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
|
||||
all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
|
||||
return all_depths, all_colors, all_densities
|
||||
|
||||
def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2):
|
||||
all_depths = torch.cat([depths1, depths2], dim = -2)
|
||||
all_colors = torch.cat([colors1, colors2], dim = -2)
|
||||
all_densities = torch.cat([densities1, densities2], dim = -2)
|
||||
|
||||
_, indices = torch.sort(all_depths, dim=-2)
|
||||
all_depths = torch.gather(all_depths, -2, indices)
|
||||
all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1]))
|
||||
all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1))
|
||||
|
||||
return all_depths, all_colors, all_densities
|
||||
|
||||
def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False):
|
||||
"""
|
||||
Return depths of approximately uniformly spaced samples along rays.
|
||||
"""
|
||||
N, M, _ = ray_origins.shape
|
||||
if disparity_space_sampling:
|
||||
depths_coarse = torch.linspace(0,
|
||||
1,
|
||||
depth_resolution,
|
||||
device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
|
||||
depth_delta = 1/(depth_resolution - 1)
|
||||
depths_coarse += torch.rand_like(depths_coarse) * depth_delta
|
||||
depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse)
|
||||
else:
|
||||
if type(ray_start) == torch.Tensor:
|
||||
depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3)
|
||||
depth_delta = (ray_end - ray_start) / (depth_resolution - 1)
|
||||
depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None]
|
||||
else:
|
||||
depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1)
|
||||
depth_delta = (ray_end - ray_start)/(depth_resolution - 1)
|
||||
depths_coarse += torch.rand_like(depths_coarse) * depth_delta
|
||||
|
||||
return depths_coarse
|
||||
|
||||
def sample_importance(self, z_vals, weights, N_importance):
|
||||
"""
|
||||
Return depths of importance sampled points along rays. See NeRF importance sampling for more.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
batch_size, num_rays, samples_per_ray, _ = z_vals.shape
|
||||
|
||||
z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray)
|
||||
weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher
|
||||
|
||||
# smooth weights
|
||||
weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1).float(), 2, 1, padding=1)
|
||||
weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze()
|
||||
weights = weights + 0.01
|
||||
|
||||
z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:])
|
||||
importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1],
|
||||
N_importance).detach().reshape(batch_size, num_rays, N_importance, 1)
|
||||
return importance_z_vals
|
||||
|
||||
def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5):
|
||||
"""
|
||||
Sample @N_importance samples from @bins with distribution defined by @weights.
|
||||
Inputs:
|
||||
bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2"
|
||||
weights: (N_rays, N_samples_)
|
||||
N_importance: the number of samples to draw from the distribution
|
||||
det: deterministic or not
|
||||
eps: a small number to prevent division by zero
|
||||
Outputs:
|
||||
samples: the sampled samples
|
||||
"""
|
||||
N_rays, N_samples_ = weights.shape
|
||||
weights = weights + eps # prevent division by zero (don't do inplace op!)
|
||||
pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_)
|
||||
cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function
|
||||
cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1)
|
||||
# padded to 0~1 inclusive
|
||||
|
||||
if det:
|
||||
u = torch.linspace(0, 1, N_importance, device=bins.device)
|
||||
u = u.expand(N_rays, N_importance)
|
||||
else:
|
||||
u = torch.rand(N_rays, N_importance, device=bins.device)
|
||||
u = u.contiguous()
|
||||
|
||||
inds = torch.searchsorted(cdf, u, right=True)
|
||||
below = torch.clamp_min(inds-1, 0)
|
||||
above = torch.clamp_max(inds, N_samples_)
|
||||
|
||||
inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance)
|
||||
cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2)
|
||||
bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2)
|
||||
|
||||
denom = cdf_g[...,1]-cdf_g[...,0]
|
||||
denom[denom<eps] = 1 # denom equals 0 means a bin has weight 0, in which case it will not be sampled
|
||||
# anyway, therefore any value for it is fine (set to 1 here)
|
||||
|
||||
samples = bins_g[...,0] + (u-cdf_g[...,0])/denom * (bins_g[...,1]-bins_g[...,0])
|
||||
return samples
|
0
svrm/ldm/modules/translator/__init__.py
Normal file
127
svrm/ldm/modules/translator/img_to_triplane.py
Normal file
@ -0,0 +1,127 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..attention import ImgToTriplaneTransformer
|
||||
import math
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class ImgToTriplaneModel(nn.Module):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding.
|
||||
:param in_channels: channels in the input Tensor.
|
||||
:param model_channels: base channel count for the model.
|
||||
:param out_channels: channels in the output Tensor.
|
||||
:param num_res_blocks: number of residual blocks per downsample.
|
||||
:param attention_resolutions: a collection of downsample rates at which
|
||||
attention will take place. May be a set, list, or tuple.
|
||||
For example, if this contains 4, then at 4x downsampling, attention
|
||||
will be used.
|
||||
:param dropout: the dropout probability.
|
||||
:param channel_mult: channel multiplier for each level of the UNet.
|
||||
:param conv_resample: if True, use learned convolutions for upsampling and
|
||||
downsampling.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param num_classes: if specified (as an int), then this model will be
|
||||
class-conditional with `num_classes` classes.
|
||||
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
||||
:param num_heads: the number of attention heads in each attention layer.
|
||||
:param num_heads_channels: if specified, ignore num_heads and instead use
|
||||
a fixed channel width per attention head.
|
||||
:param num_heads_upsample: works with num_heads to set a different number
|
||||
of heads for upsampling. Deprecated.
|
||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
||||
:param resblock_updown: use residual blocks for up/downsampling.
|
||||
:param use_new_attention_order: use a different attention pattern for potentially
|
||||
increased efficiency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
pos_emb_size=32,
|
||||
pos_emb_dim=1024,
|
||||
cam_cond_dim=20,
|
||||
n_heads=16,
|
||||
d_head=64,
|
||||
depth=16,
|
||||
context_dim=768,
|
||||
triplane_dim=80,
|
||||
upsample_time=1,
|
||||
use_fp16=False,
|
||||
use_bf16=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.pos_emb_size = pos_emb_size
|
||||
self.pos_emb_dim = pos_emb_dim
|
||||
|
||||
# init embedding
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, 3 * pos_emb_size * pos_emb_size, pos_emb_dim))
|
||||
# TODO initialize pos_emb with a Gaussian random of zero-mean and std of 1/sqrt(1024).
|
||||
|
||||
# build image to triplane decoder
|
||||
self.img_to_triplane_decoder = ImgToTriplaneTransformer(
|
||||
query_dim=pos_emb_dim, n_heads=n_heads,
|
||||
d_head=d_head, depth=depth, context_dim=context_dim,
|
||||
triplane_size=pos_emb_size,
|
||||
)
|
||||
|
||||
self.is_conv_upsampler = False
|
||||
# build upsampler
|
||||
self.triplane_dim = triplane_dim
|
||||
if self.is_conv_upsampler:
|
||||
upsamplers = []
|
||||
for i in range(upsample_time):
|
||||
if i == 0:
|
||||
upsampler = nn.ConvTranspose2d(in_channels=pos_emb_dim, out_channels=triplane_dim,
|
||||
kernel_size=2, stride=2,
|
||||
padding=0, output_padding=0)
|
||||
upsamplers.append(upsampler)
|
||||
else:
|
||||
upsampler = nn.ConvTranspose2d(in_channels=triplane_dim, out_channels=triplane_dim,
|
||||
kernel_size=2, stride=2,
|
||||
padding=0, output_padding=0)
|
||||
upsamplers.append(upsampler)
|
||||
if upsamplers:
|
||||
self.upsampler = nn.Sequential(*upsamplers)
|
||||
else:
|
||||
self.upsampler = nn.Conv2d(in_channels=pos_emb_dim, out_channels=triplane_dim,
|
||||
kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
self.upsample_ratio = 4
|
||||
self.upsampler = nn.Linear(in_features=pos_emb_dim, out_features=triplane_dim*(self.upsample_ratio**2))
|
||||
|
||||
|
||||
|
||||
def forward(self, x, cam_cond=None, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:param context: conditioning plugged in via crossattn
|
||||
:param y: an [N] Tensor of labels, if class-conditional.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
|
||||
B = x.shape[0]
|
||||
h = self.pos_emb.expand(B, -1, -1)
|
||||
context = x
|
||||
|
||||
h = self.img_to_triplane_decoder(h, context=context)
|
||||
|
||||
h = h.view(B * 3, self.pos_emb_size, self.pos_emb_size, self.pos_emb_dim)
|
||||
if self.is_conv_upsampler:
|
||||
h = rearrange(h, 'b h w c -> b c h w')
|
||||
h = self.upsampler(h)
|
||||
h = rearrange(h, '(b d) c h w-> b d c h w', d=3)
|
||||
h = h.type(x.dtype)
|
||||
return h
|
||||
else:
|
||||
h = self.upsampler(h) #[b, h, w, triplane_dim*4]
|
||||
b, height, width, _ = h.shape
|
||||
h = h.view(b, height, width, self.triplane_dim, self.upsample_ratio, self.upsample_ratio) #[b, h, w, triplane_dim, 2, 2]
|
||||
h = h.permute(0,3,1,4,2,5).contiguous() #[b, triplane_dim, h, 2, w, 2]
|
||||
h = h.view(b, self.triplane_dim, height*self.upsample_ratio, width*self.upsample_ratio)
|
||||
h = rearrange(h, '(b d) c h w-> b d c h w', d=3)
|
||||
h = h.type(x.dtype)
|
||||
return h
|
641
svrm/ldm/modules/x_transformer.py
Normal file
@ -0,0 +1,641 @@
|
||||
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
from inspect import isfunction
|
||||
from collections import namedtuple
|
||||
from einops import rearrange, repeat, reduce
|
||||
|
||||
# constants
|
||||
|
||||
DEFAULT_DIM_HEAD = 64
|
||||
|
||||
Intermediates = namedtuple('Intermediates', [
|
||||
'pre_softmax_attn',
|
||||
'post_softmax_attn'
|
||||
])
|
||||
|
||||
LayerIntermediates = namedtuple('Intermediates', [
|
||||
'hiddens',
|
||||
'attn_intermediates'
|
||||
])
|
||||
|
||||
|
||||
class AbsolutePositionalEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_seq_len):
|
||||
super().__init__()
|
||||
self.emb = nn.Embedding(max_seq_len, dim)
|
||||
self.init_()
|
||||
|
||||
def init_(self):
|
||||
nn.init.normal_(self.emb.weight, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
n = torch.arange(x.shape[1], device=x.device)
|
||||
return self.emb(n)[None, :, :]
|
||||
|
||||
|
||||
class FixedPositionalEmbedding(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer('inv_freq', inv_freq)
|
||||
|
||||
def forward(self, x, seq_dim=1, offset=0):
|
||||
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
|
||||
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
|
||||
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
|
||||
return emb[None, :, :]
|
||||
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def always(val):
|
||||
def inner(*args, **kwargs):
|
||||
return val
|
||||
return inner
|
||||
|
||||
|
||||
def not_equals(val):
|
||||
def inner(x):
|
||||
return x != val
|
||||
return inner
|
||||
|
||||
|
||||
def equals(val):
|
||||
def inner(x):
|
||||
return x == val
|
||||
return inner
|
||||
|
||||
|
||||
def max_neg_value(tensor):
|
||||
return -torch.finfo(tensor.dtype).max
|
||||
|
||||
|
||||
# keyword argument helpers
|
||||
|
||||
def pick_and_pop(keys, d):
|
||||
values = list(map(lambda key: d.pop(key), keys))
|
||||
return dict(zip(keys, values))
|
||||
|
||||
|
||||
def group_dict_by_key(cond, d):
|
||||
return_val = [dict(), dict()]
|
||||
for key in d.keys():
|
||||
match = bool(cond(key))
|
||||
ind = int(not match)
|
||||
return_val[ind][key] = d[key]
|
||||
return (*return_val,)
|
||||
|
||||
|
||||
def string_begins_with(prefix, str):
|
||||
return str.startswith(prefix)
|
||||
|
||||
|
||||
def group_by_key_prefix(prefix, d):
|
||||
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
|
||||
|
||||
def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
|
||||
# classes
|
||||
class Scale(nn.Module):
|
||||
def __init__(self, value, fn):
|
||||
super().__init__()
|
||||
self.value = value
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
x, *rest = self.fn(x, **kwargs)
|
||||
return (x * self.value, *rest)
|
||||
|
||||
|
||||
class Rezero(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.g = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
x, *rest = self.fn(x, **kwargs)
|
||||
return (x * self.g, *rest)
|
||||
|
||||
|
||||
class ScaleNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.scale = dim ** -0.5
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1))
|
||||
|
||||
def forward(self, x):
|
||||
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
||||
return x / norm.clamp(min=self.eps) * self.g
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-8):
|
||||
super().__init__()
|
||||
self.scale = dim ** -0.5
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
||||
return x / norm.clamp(min=self.eps) * self.g
|
||||
|
||||
|
||||
class Residual(nn.Module):
|
||||
def forward(self, x, residual):
|
||||
return x + residual
|
||||
|
||||
|
||||
class GRUGating(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gru = nn.GRUCell(dim, dim)
|
||||
|
||||
def forward(self, x, residual):
|
||||
gated_output = self.gru(
|
||||
rearrange(x, 'b n d -> (b n) d'),
|
||||
rearrange(residual, 'b n d -> (b n) d')
|
||||
)
|
||||
|
||||
return gated_output.reshape_as(x)
|
||||
|
||||
|
||||
# feedforward
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
# attention.
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_head=DEFAULT_DIM_HEAD,
|
||||
heads=8,
|
||||
causal=False,
|
||||
mask=None,
|
||||
talking_heads=False,
|
||||
sparse_topk=None,
|
||||
use_entmax15=False,
|
||||
num_mem_kv=0,
|
||||
dropout=0.,
|
||||
on_attn=False
|
||||
):
|
||||
super().__init__()
|
||||
if use_entmax15:
|
||||
raise NotImplementedError("Check out entmax activation instead of softmax activation!")
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
self.causal = causal
|
||||
self.mask = mask
|
||||
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
# talking heads
|
||||
self.talking_heads = talking_heads
|
||||
if talking_heads:
|
||||
self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
|
||||
self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
|
||||
|
||||
# explicit topk sparse attention
|
||||
self.sparse_topk = sparse_topk
|
||||
|
||||
# entmax
|
||||
#self.attn_fn = entmax15 if use_entmax15 else F.softmax
|
||||
self.attn_fn = F.softmax
|
||||
|
||||
# add memory key / values
|
||||
self.num_mem_kv = num_mem_kv
|
||||
if num_mem_kv > 0:
|
||||
self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
||||
self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
||||
|
||||
# attention on attention
|
||||
self.attn_on_attn = on_attn
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
context_mask=None,
|
||||
rel_pos=None,
|
||||
sinusoidal_emb=None,
|
||||
prev_attn=None,
|
||||
mem=None
|
||||
):
|
||||
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
|
||||
kv_input = default(context, x)
|
||||
|
||||
q_input = x
|
||||
k_input = kv_input
|
||||
v_input = kv_input
|
||||
|
||||
if exists(mem):
|
||||
k_input = torch.cat((mem, k_input), dim=-2)
|
||||
v_input = torch.cat((mem, v_input), dim=-2)
|
||||
|
||||
if exists(sinusoidal_emb):
|
||||
# in shortformer, the query would start at a position offset depending on the past cached memory
|
||||
offset = k_input.shape[-2] - q_input.shape[-2]
|
||||
q_input = q_input + sinusoidal_emb(q_input, offset=offset)
|
||||
k_input = k_input + sinusoidal_emb(k_input)
|
||||
|
||||
q = self.to_q(q_input)
|
||||
k = self.to_k(k_input)
|
||||
v = self.to_v(v_input)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
|
||||
|
||||
input_mask = None
|
||||
if any(map(exists, (mask, context_mask))):
|
||||
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
|
||||
k_mask = q_mask if not exists(context) else context_mask
|
||||
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
|
||||
q_mask = rearrange(q_mask, 'b i -> b () i ()')
|
||||
k_mask = rearrange(k_mask, 'b j -> b () () j')
|
||||
input_mask = q_mask * k_mask
|
||||
|
||||
if self.num_mem_kv > 0:
|
||||
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
|
||||
k = torch.cat((mem_k, k), dim=-2)
|
||||
v = torch.cat((mem_v, v), dim=-2)
|
||||
if exists(input_mask):
|
||||
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
mask_value = max_neg_value(dots)
|
||||
|
||||
if exists(prev_attn):
|
||||
dots = dots + prev_attn
|
||||
|
||||
pre_softmax_attn = dots
|
||||
|
||||
if talking_heads:
|
||||
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
|
||||
|
||||
if exists(rel_pos):
|
||||
dots = rel_pos(dots)
|
||||
|
||||
if exists(input_mask):
|
||||
dots.masked_fill_(~input_mask, mask_value)
|
||||
del input_mask
|
||||
|
||||
if self.causal:
|
||||
i, j = dots.shape[-2:]
|
||||
r = torch.arange(i, device=device)
|
||||
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
|
||||
mask = F.pad(mask, (j - i, 0), value=False)
|
||||
dots.masked_fill_(mask, mask_value)
|
||||
del mask
|
||||
|
||||
if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
|
||||
top, _ = dots.topk(self.sparse_topk, dim=-1)
|
||||
vk = top[..., -1].unsqueeze(-1).expand_as(dots)
|
||||
mask = dots < vk
|
||||
dots.masked_fill_(mask, mask_value)
|
||||
del mask
|
||||
|
||||
attn = self.attn_fn(dots, dim=-1)
|
||||
post_softmax_attn = attn
|
||||
|
||||
attn = self.dropout(attn)
|
||||
|
||||
if talking_heads:
|
||||
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
intermediates = Intermediates(
|
||||
pre_softmax_attn=pre_softmax_attn,
|
||||
post_softmax_attn=post_softmax_attn
|
||||
)
|
||||
|
||||
return self.to_out(out), intermediates
|
||||
|
||||
|
||||
class AttentionLayers(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
heads=8,
|
||||
causal=False,
|
||||
cross_attend=False,
|
||||
only_cross=False,
|
||||
use_scalenorm=False,
|
||||
use_rmsnorm=False,
|
||||
use_rezero=False,
|
||||
rel_pos_num_buckets=32,
|
||||
rel_pos_max_distance=128,
|
||||
position_infused_attn=False,
|
||||
custom_layers=None,
|
||||
sandwich_coef=None,
|
||||
par_ratio=None,
|
||||
residual_attn=False,
|
||||
cross_residual_attn=False,
|
||||
macaron=False,
|
||||
pre_norm=True,
|
||||
gate_residual=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
|
||||
attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
|
||||
|
||||
dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
|
||||
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
self.has_pos_emb = position_infused_attn
|
||||
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
|
||||
self.rotary_pos_emb = always(None)
|
||||
|
||||
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
|
||||
self.rel_pos = None
|
||||
|
||||
self.pre_norm = pre_norm
|
||||
|
||||
self.residual_attn = residual_attn
|
||||
self.cross_residual_attn = cross_residual_attn
|
||||
|
||||
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
|
||||
norm_class = RMSNorm if use_rmsnorm else norm_class
|
||||
norm_fn = partial(norm_class, dim)
|
||||
|
||||
norm_fn = nn.Identity if use_rezero else norm_fn
|
||||
branch_fn = Rezero if use_rezero else None
|
||||
|
||||
if cross_attend and not only_cross:
|
||||
default_block = ('a', 'c', 'f')
|
||||
elif cross_attend and only_cross:
|
||||
default_block = ('c', 'f')
|
||||
else:
|
||||
default_block = ('a', 'f')
|
||||
|
||||
if macaron:
|
||||
default_block = ('f',) + default_block
|
||||
|
||||
if exists(custom_layers):
|
||||
layer_types = custom_layers
|
||||
elif exists(par_ratio):
|
||||
par_depth = depth * len(default_block)
|
||||
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
|
||||
default_block = tuple(filter(not_equals('f'), default_block))
|
||||
par_attn = par_depth // par_ratio
|
||||
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
|
||||
par_width = (depth_cut + depth_cut // par_attn) // par_attn
|
||||
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
|
||||
par_block = default_block + ('f',) * (par_width - len(default_block))
|
||||
par_head = par_block * par_attn
|
||||
layer_types = par_head + ('f',) * (par_depth - len(par_head))
|
||||
elif exists(sandwich_coef):
|
||||
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
|
||||
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
|
||||
else:
|
||||
layer_types = default_block * depth
|
||||
|
||||
self.layer_types = layer_types
|
||||
self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
|
||||
|
||||
for layer_type in self.layer_types:
|
||||
if layer_type == 'a':
|
||||
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
|
||||
elif layer_type == 'c':
|
||||
layer = Attention(dim, heads=heads, **attn_kwargs)
|
||||
elif layer_type == 'f':
|
||||
layer = FeedForward(dim, **ff_kwargs)
|
||||
layer = layer if not macaron else Scale(0.5, layer)
|
||||
else:
|
||||
raise Exception(f'invalid layer type {layer_type}')
|
||||
|
||||
if isinstance(layer, Attention) and exists(branch_fn):
|
||||
layer = branch_fn(layer)
|
||||
|
||||
if gate_residual:
|
||||
residual_fn = GRUGating(dim)
|
||||
else:
|
||||
residual_fn = Residual()
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
norm_fn(),
|
||||
layer,
|
||||
residual_fn
|
||||
]))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
context_mask=None,
|
||||
mems=None,
|
||||
return_hiddens=False
|
||||
):
|
||||
hiddens = []
|
||||
intermediates = []
|
||||
prev_attn = None
|
||||
prev_cross_attn = None
|
||||
|
||||
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
|
||||
|
||||
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
|
||||
is_last = ind == (len(self.layers) - 1)
|
||||
|
||||
if layer_type == 'a':
|
||||
hiddens.append(x)
|
||||
layer_mem = mems.pop(0)
|
||||
|
||||
residual = x
|
||||
|
||||
if self.pre_norm:
|
||||
x = norm(x)
|
||||
|
||||
if layer_type == 'a':
|
||||
out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
|
||||
prev_attn=prev_attn, mem=layer_mem)
|
||||
elif layer_type == 'c':
|
||||
out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
|
||||
elif layer_type == 'f':
|
||||
out = block(x)
|
||||
|
||||
x = residual_fn(out, residual)
|
||||
|
||||
if layer_type in ('a', 'c'):
|
||||
intermediates.append(inter)
|
||||
|
||||
if layer_type == 'a' and self.residual_attn:
|
||||
prev_attn = inter.pre_softmax_attn
|
||||
elif layer_type == 'c' and self.cross_residual_attn:
|
||||
prev_cross_attn = inter.pre_softmax_attn
|
||||
|
||||
if not self.pre_norm and not is_last:
|
||||
x = norm(x)
|
||||
|
||||
if return_hiddens:
|
||||
intermediates = LayerIntermediates(
|
||||
hiddens=hiddens,
|
||||
attn_intermediates=intermediates
|
||||
)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(AttentionLayers):
|
||||
def __init__(self, **kwargs):
|
||||
assert 'causal' not in kwargs, 'cannot set causality on encoder'
|
||||
super().__init__(causal=False, **kwargs)
|
||||
|
||||
|
||||
|
||||
class TransformerWrapper(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_tokens,
|
||||
max_seq_len,
|
||||
attn_layers,
|
||||
emb_dim=None,
|
||||
max_mem_len=0.,
|
||||
emb_dropout=0.,
|
||||
num_memory_tokens=None,
|
||||
tie_embedding=False,
|
||||
use_pos_emb=True
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
|
||||
|
||||
dim = attn_layers.dim
|
||||
emb_dim = default(emb_dim, dim)
|
||||
|
||||
self.max_seq_len = max_seq_len
|
||||
self.max_mem_len = max_mem_len
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
self.token_emb = nn.Embedding(num_tokens, emb_dim)
|
||||
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
|
||||
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
|
||||
self.emb_dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
|
||||
self.attn_layers = attn_layers
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.init_()
|
||||
|
||||
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
|
||||
|
||||
# memory tokens (like [cls]) from Memory Transformers paper
|
||||
num_memory_tokens = default(num_memory_tokens, 0)
|
||||
self.num_memory_tokens = num_memory_tokens
|
||||
if num_memory_tokens > 0:
|
||||
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
|
||||
|
||||
# let funnel encoder know number of memory tokens, if specified
|
||||
if hasattr(attn_layers, 'num_memory_tokens'):
|
||||
attn_layers.num_memory_tokens = num_memory_tokens
|
||||
|
||||
def init_(self):
|
||||
nn.init.normal_(self.token_emb.weight, std=0.02)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_embeddings=False,
|
||||
mask=None,
|
||||
return_mems=False,
|
||||
return_attn=False,
|
||||
mems=None,
|
||||
**kwargs
|
||||
):
|
||||
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
|
||||
x = self.token_emb(x)
|
||||
x += self.pos_emb(x)
|
||||
x = self.emb_dropout(x)
|
||||
|
||||
x = self.project_emb(x)
|
||||
|
||||
if num_mem > 0:
|
||||
mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
|
||||
x = torch.cat((mem, x), dim=1)
|
||||
|
||||
# auto-handle masking after appending memory tokens
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (num_mem, 0), value=True)
|
||||
|
||||
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
|
||||
x = self.norm(x)
|
||||
|
||||
mem, x = x[:, :num_mem], x[:, num_mem:]
|
||||
|
||||
out = self.to_logits(x) if not return_embeddings else x
|
||||
|
||||
if return_mems:
|
||||
hiddens = intermediates.hiddens
|
||||
new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
|
||||
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
|
||||
return out, new_mems
|
||||
|
||||
if return_attn:
|
||||
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
|
||||
return out, attn_maps
|
||||
|
||||
return out
|
||||
|
252
svrm/ldm/util.py
Normal file
@ -0,0 +1,252 @@
|
||||
import os
|
||||
import importlib
|
||||
from inspect import isfunction
|
||||
import cv2
|
||||
import time
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from torch import optim
|
||||
import torchvision
|
||||
|
||||
|
||||
def pil_rectangle_crop(im):
|
||||
width, height = im.size # Get dimensions
|
||||
|
||||
if width <= height:
|
||||
left = 0
|
||||
right = width
|
||||
top = (height - width)/2
|
||||
bottom = (height + width)/2
|
||||
else:
|
||||
|
||||
top = 0
|
||||
bottom = height
|
||||
left = (width - height) / 2
|
||||
bottom = (width + height) / 2
|
||||
|
||||
# Crop the center of the image
|
||||
im = im.crop((left, top, right, bottom))
|
||||
return im
|
||||
|
||||
|
||||
def add_margin(pil_img, color, size=256):
|
||||
width, height = pil_img.size
|
||||
result = Image.new(pil_img.mode, (size, size), color)
|
||||
result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
|
||||
return result
|
||||
|
||||
|
||||
def load_and_preprocess(interface, input_im):
|
||||
'''
|
||||
:param input_im (PIL Image).
|
||||
:return image (H, W, 3) array in [0, 1].
|
||||
'''
|
||||
# See https://github.com/Ir1d/image-background-remove-tool
|
||||
image = input_im.convert('RGB')
|
||||
|
||||
image_without_background = interface([image])[0]
|
||||
image_without_background = np.array(image_without_background)
|
||||
est_seg = image_without_background > 127
|
||||
image = np.array(image)
|
||||
foreground = est_seg[:, : , -1].astype(np.bool_)
|
||||
image[~foreground] = [255., 255., 255.]
|
||||
x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8))
|
||||
image = image[y:y+h, x:x+w, :]
|
||||
image = Image.fromarray(np.array(image))
|
||||
|
||||
# resize image such that long edge is 512
|
||||
image.thumbnail([200, 200], Image.Resampling.LANCZOS)
|
||||
image = add_margin(image, (255, 255, 255), size=256)
|
||||
image = np.array(image)
|
||||
return image
|
||||
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
# xc a list of captions to plot
|
||||
b = len(xc)
|
||||
txts = list()
|
||||
for bi in range(b):
|
||||
txt = Image.new("RGB", wh, color="white")
|
||||
draw = ImageDraw.Draw(txt)
|
||||
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
|
||||
nc = int(40 * (wh[0] / 256))
|
||||
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
|
||||
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
except UnicodeEncodeError:
|
||||
print("Cant encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
txts = np.stack(txts)
|
||||
txts = torch.tensor(txts)
|
||||
return txts
|
||||
|
||||
|
||||
def ismap(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
||||
|
||||
|
||||
def isimage(x):
|
||||
if not isinstance(x,torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
||||
return total_params
|
||||
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if not "target" in config:
|
||||
if config == '__is_first_stage__':
|
||||
return None
|
||||
elif config == "__is_unconditional__":
|
||||
return None
|
||||
raise KeyError("Expected key `target` to instantiate.")
|
||||
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
module, cls = string.rsplit(".", 1)
|
||||
if reload:
|
||||
module_imp = importlib.import_module(module)
|
||||
importlib.reload(module_imp)
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
|
||||
|
||||
class AdamWwithEMAandWings(optim.Optimizer):
|
||||
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
|
||||
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
|
||||
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
|
||||
ema_power=1., param_names=()):
|
||||
"""AdamW that saves EMA versions of the parameters."""
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
||||
if not 0.0 <= weight_decay:
|
||||
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
||||
if not 0.0 <= ema_decay <= 1.0:
|
||||
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
|
||||
ema_power=ema_power, param_names=param_names)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
super().__setstate__(state)
|
||||
for group in self.param_groups:
|
||||
group.setdefault('amsgrad', False)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
Args:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad = []
|
||||
grads = []
|
||||
exp_avgs = []
|
||||
exp_avg_sqs = []
|
||||
ema_params_with_grad = []
|
||||
state_sums = []
|
||||
max_exp_avg_sqs = []
|
||||
state_steps = []
|
||||
amsgrad = group['amsgrad']
|
||||
beta1, beta2 = group['betas']
|
||||
ema_decay = group['ema_decay']
|
||||
ema_power = group['ema_power']
|
||||
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
params_with_grad.append(p)
|
||||
if p.grad.is_sparse:
|
||||
raise RuntimeError('AdamW does not support sparse gradients')
|
||||
grads.append(p.grad)
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
if amsgrad:
|
||||
# Maintains max of all exp. moving avg. of sq. grad. values
|
||||
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
# Exponential moving average of parameter values
|
||||
state['param_exp_avg'] = p.detach().float().clone()
|
||||
|
||||
exp_avgs.append(state['exp_avg'])
|
||||
exp_avg_sqs.append(state['exp_avg_sq'])
|
||||
ema_params_with_grad.append(state['param_exp_avg'])
|
||||
|
||||
if amsgrad:
|
||||
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
|
||||
|
||||
# update the steps for each param group update
|
||||
state['step'] += 1
|
||||
# record the step after step update
|
||||
state_steps.append(state['step'])
|
||||
|
||||
optim._functional.adamw(params_with_grad,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
max_exp_avg_sqs,
|
||||
state_steps,
|
||||
amsgrad=amsgrad,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
lr=group['lr'],
|
||||
weight_decay=group['weight_decay'],
|
||||
eps=group['eps'],
|
||||
maximize=False)
|
||||
|
||||
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
|
||||
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
|
||||
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
|
||||
|
||||
return loss
|
538
svrm/ldm/utils/ops.py
Normal file
@ -0,0 +1,538 @@
|
||||
import os
|
||||
import math
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Function
|
||||
from torch.cuda.amp import custom_bwd, custom_fwd
|
||||
from igl import fast_winding_number_for_meshes, point_mesh_squared_distance, read_obj
|
||||
|
||||
from .typing import *
|
||||
|
||||
|
||||
def get_rank():
|
||||
# SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
|
||||
# therefore LOCAL_RANK needs to be checked first
|
||||
rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
|
||||
for key in rank_keys:
|
||||
rank = os.environ.get(key)
|
||||
if rank is not None:
|
||||
return int(rank)
|
||||
return 0
|
||||
|
||||
def dot(x, y):
|
||||
return torch.sum(x * y, -1, keepdim=True)
|
||||
|
||||
|
||||
def reflect(x, n):
|
||||
return 2 * dot(x, n) * n - x
|
||||
|
||||
|
||||
ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]]
|
||||
|
||||
|
||||
def scale_tensor(
|
||||
dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale
|
||||
):
|
||||
if inp_scale is None:
|
||||
inp_scale = (0, 1)
|
||||
if tgt_scale is None:
|
||||
tgt_scale = (0, 1)
|
||||
if isinstance(tgt_scale, Tensor):
|
||||
assert dat.shape[-1] == tgt_scale.shape[-1]
|
||||
dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0])
|
||||
dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0]
|
||||
return dat
|
||||
|
||||
|
||||
class _TruncExp(Function): # pylint: disable=abstract-method
|
||||
# Implementation from torch-ngp:
|
||||
# https://github.com/ashawkey/torch-ngp/blob/93b08a0d4ec1cc6e69d85df7f0acdfb99603b628/activation.py
|
||||
@staticmethod
|
||||
@custom_fwd(cast_inputs=torch.float32)
|
||||
def forward(ctx, x): # pylint: disable=arguments-differ
|
||||
ctx.save_for_backward(x)
|
||||
return torch.exp(x)
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, g): # pylint: disable=arguments-differ
|
||||
x = ctx.saved_tensors[0]
|
||||
return g * torch.exp(torch.clamp(x, max=15))
|
||||
|
||||
|
||||
class SpecifyGradient(Function):
|
||||
# Implementation from stable-dreamfusion
|
||||
# https://github.com/ashawkey/stable-dreamfusion
|
||||
@staticmethod
|
||||
@custom_fwd
|
||||
def forward(ctx, input_tensor, gt_grad):
|
||||
ctx.save_for_backward(gt_grad)
|
||||
# we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward.
|
||||
return torch.ones([1], device=input_tensor.device, dtype=input_tensor.dtype)
|
||||
|
||||
@staticmethod
|
||||
@custom_bwd
|
||||
def backward(ctx, grad_scale):
|
||||
(gt_grad,) = ctx.saved_tensors
|
||||
gt_grad = gt_grad * grad_scale
|
||||
return gt_grad, None
|
||||
|
||||
|
||||
trunc_exp = _TruncExp.apply
|
||||
|
||||
|
||||
def get_activation(name) -> Callable:
|
||||
if name is None:
|
||||
return lambda x: x
|
||||
name = name.lower()
|
||||
if name == "none":
|
||||
return lambda x: x
|
||||
elif name == "lin2srgb":
|
||||
return lambda x: torch.where(
|
||||
x > 0.0031308,
|
||||
torch.pow(torch.clamp(x, min=0.0031308), 1.0 / 2.4) * 1.055 - 0.055,
|
||||
12.92 * x,
|
||||
).clamp(0.0, 1.0)
|
||||
elif name == "exp":
|
||||
return lambda x: torch.exp(x)
|
||||
elif name == "shifted_exp":
|
||||
return lambda x: torch.exp(x - 1.0)
|
||||
elif name == "trunc_exp":
|
||||
return trunc_exp
|
||||
elif name == "shifted_trunc_exp":
|
||||
return lambda x: trunc_exp(x - 1.0)
|
||||
elif name == "sigmoid":
|
||||
return lambda x: torch.sigmoid(x)
|
||||
elif name == "tanh":
|
||||
return lambda x: torch.tanh(x)
|
||||
elif name == "shifted_softplus":
|
||||
return lambda x: F.softplus(x - 1.0)
|
||||
elif name == "scale_-11_01":
|
||||
return lambda x: x * 0.5 + 0.5
|
||||
else:
|
||||
try:
|
||||
return getattr(F, name)
|
||||
except AttributeError:
|
||||
raise ValueError(f"Unknown activation function: {name}")
|
||||
|
||||
|
||||
def chunk_batch(func: Callable, chunk_size: int, triplane=None, *args, **kwargs) -> Any:
|
||||
if chunk_size <= 0:
|
||||
return func(*args, **kwargs)
|
||||
B = None
|
||||
for arg in list(args) + list(kwargs.values()):
|
||||
if isinstance(arg, torch.Tensor):
|
||||
B = arg.shape[0]
|
||||
break
|
||||
assert (
|
||||
B is not None
|
||||
), "No tensor found in args or kwargs, cannot determine batch size."
|
||||
out = defaultdict(list)
|
||||
out_type = None
|
||||
# max(1, B) to support B == 0
|
||||
for i in range(0, max(1, B), chunk_size):
|
||||
if triplane is not None:
|
||||
out_chunk = func(triplane=triplane,
|
||||
*[
|
||||
arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
|
||||
for arg in args
|
||||
],
|
||||
**{
|
||||
k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
|
||||
for k, arg in kwargs.items()
|
||||
},
|
||||
)
|
||||
else:
|
||||
out_chunk = func(
|
||||
*[
|
||||
arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
|
||||
for arg in args
|
||||
],
|
||||
**{
|
||||
k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg
|
||||
for k, arg in kwargs.items()
|
||||
},
|
||||
)
|
||||
if out_chunk is None:
|
||||
continue
|
||||
out_type = type(out_chunk)
|
||||
if isinstance(out_chunk, torch.Tensor):
|
||||
out_chunk = {0: out_chunk}
|
||||
elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list):
|
||||
chunk_length = len(out_chunk)
|
||||
out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)}
|
||||
elif isinstance(out_chunk, dict):
|
||||
pass
|
||||
else:
|
||||
print(
|
||||
f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}."
|
||||
)
|
||||
exit(1)
|
||||
for k, v in out_chunk.items():
|
||||
v = v if torch.is_grad_enabled() else v.detach()
|
||||
out[k].append(v)
|
||||
|
||||
if out_type is None:
|
||||
return None
|
||||
|
||||
out_merged: Dict[Any, Optional[torch.Tensor]] = {}
|
||||
for k, v in out.items():
|
||||
if all([vv is None for vv in v]):
|
||||
# allow None in return value
|
||||
out_merged[k] = None
|
||||
elif all([isinstance(vv, torch.Tensor) for vv in v]):
|
||||
out_merged[k] = torch.cat(v, dim=0)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}"
|
||||
)
|
||||
|
||||
if out_type is torch.Tensor:
|
||||
return out_merged[0]
|
||||
elif out_type in [tuple, list]:
|
||||
return out_type([out_merged[i] for i in range(chunk_length)])
|
||||
elif out_type is dict:
|
||||
return out_merged
|
||||
|
||||
|
||||
def get_ray_directions(
|
||||
H: int,
|
||||
W: int,
|
||||
focal: Union[float, Tuple[float, float]],
|
||||
principal: Optional[Tuple[float, float]] = None,
|
||||
use_pixel_centers: bool = True,
|
||||
) -> Float[Tensor, "H W 3"]:
|
||||
"""
|
||||
Get ray directions for all pixels in camera coordinate.
|
||||
Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
|
||||
ray-tracing-generating-camera-rays/standard-coordinate-systems
|
||||
|
||||
Inputs:
|
||||
H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers
|
||||
Outputs:
|
||||
directions: (H, W, 3), the direction of the rays in camera coordinate
|
||||
"""
|
||||
pixel_center = 0.5 if use_pixel_centers else 0
|
||||
|
||||
if isinstance(focal, float):
|
||||
fx, fy = focal, focal
|
||||
cx, cy = W / 2, H / 2
|
||||
else:
|
||||
fx, fy = focal
|
||||
assert principal is not None
|
||||
cx, cy = principal
|
||||
|
||||
i, j = torch.meshgrid(
|
||||
torch.arange(W, dtype=torch.float32) + pixel_center,
|
||||
torch.arange(H, dtype=torch.float32) + pixel_center,
|
||||
indexing="xy",
|
||||
)
|
||||
|
||||
directions: Float[Tensor, "H W 3"] = torch.stack(
|
||||
[(i - cx) / fx, -(j - cy) / fy, -torch.ones_like(i)], -1
|
||||
)
|
||||
|
||||
return directions
|
||||
|
||||
|
||||
def get_rays(
|
||||
directions: Float[Tensor, "... 3"],
|
||||
c2w: Float[Tensor, "... 4 4"],
|
||||
keepdim=False,
|
||||
noise_scale=0.0,
|
||||
) -> Tuple[Float[Tensor, "... 3"], Float[Tensor, "... 3"]]:
|
||||
# Rotate ray directions from camera coordinate to the world coordinate
|
||||
assert directions.shape[-1] == 3
|
||||
|
||||
if directions.ndim == 2: # (N_rays, 3)
|
||||
if c2w.ndim == 2: # (4, 4)
|
||||
c2w = c2w[None, :, :]
|
||||
assert c2w.ndim == 3 # (N_rays, 4, 4) or (1, 4, 4)
|
||||
rays_d = (directions[:, None, :] * c2w[:, :3, :3]).sum(-1) # (N_rays, 3)
|
||||
rays_o = c2w[:, :3, 3].expand(rays_d.shape)
|
||||
elif directions.ndim == 3: # (H, W, 3)
|
||||
assert c2w.ndim in [2, 3]
|
||||
if c2w.ndim == 2: # (4, 4)
|
||||
rays_d = (directions[:, :, None, :] * c2w[None, None, :3, :3]).sum(
|
||||
-1
|
||||
) # (H, W, 3)
|
||||
rays_o = c2w[None, None, :3, 3].expand(rays_d.shape)
|
||||
elif c2w.ndim == 3: # (B, 4, 4)
|
||||
rays_d = (directions[None, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
|
||||
-1
|
||||
) # (B, H, W, 3)
|
||||
rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
|
||||
elif directions.ndim == 4: # (B, H, W, 3)
|
||||
assert c2w.ndim == 3 # (B, 4, 4)
|
||||
rays_d = (directions[:, :, :, None, :] * c2w[:, None, None, :3, :3]).sum(
|
||||
-1
|
||||
) # (B, H, W, 3)
|
||||
rays_o = c2w[:, None, None, :3, 3].expand(rays_d.shape)
|
||||
|
||||
# add camera noise to avoid grid-like artifect
|
||||
# https://github.com/ashawkey/stable-dreamfusion/blob/49c3d4fa01d68a4f027755acf94e1ff6020458cc/nerf/utils.py#L373
|
||||
if noise_scale > 0:
|
||||
rays_o = rays_o + torch.randn(3, device=rays_o.device) * noise_scale
|
||||
rays_d = rays_d + torch.randn(3, device=rays_d.device) * noise_scale
|
||||
|
||||
rays_d = F.normalize(rays_d, dim=-1)
|
||||
if not keepdim:
|
||||
rays_o, rays_d = rays_o.reshape(-1, 3), rays_d.reshape(-1, 3)
|
||||
|
||||
return rays_o, rays_d
|
||||
|
||||
|
||||
def get_projection_matrix(
|
||||
fovy: Float[Tensor, "B"], aspect_wh: float, near: float, far: float
|
||||
) -> Float[Tensor, "B 4 4"]:
|
||||
batch_size = fovy.shape[0]
|
||||
proj_mtx = torch.zeros(batch_size, 4, 4, dtype=torch.float32)
|
||||
proj_mtx[:, 0, 0] = 1.0 / (torch.tan(fovy / 2.0) * aspect_wh)
|
||||
proj_mtx[:, 1, 1] = -1.0 / torch.tan(
|
||||
fovy / 2.0
|
||||
) # add a negative sign here as the y axis is flipped in nvdiffrast output
|
||||
proj_mtx[:, 2, 2] = -(far + near) / (far - near)
|
||||
proj_mtx[:, 2, 3] = -2.0 * far * near / (far - near)
|
||||
proj_mtx[:, 3, 2] = -1.0
|
||||
return proj_mtx
|
||||
|
||||
|
||||
def get_mvp_matrix(
|
||||
c2w: Float[Tensor, "B 4 4"], proj_mtx: Float[Tensor, "B 4 4"]
|
||||
) -> Float[Tensor, "B 4 4"]:
|
||||
# calculate w2c from c2w: R' = Rt, t' = -Rt * t
|
||||
# mathematically equivalent to (c2w)^-1
|
||||
w2c: Float[Tensor, "B 4 4"] = torch.zeros(c2w.shape[0], 4, 4).to(c2w)
|
||||
w2c[:, :3, :3] = c2w[:, :3, :3].permute(0, 2, 1)
|
||||
w2c[:, :3, 3:] = -c2w[:, :3, :3].permute(0, 2, 1) @ c2w[:, :3, 3:]
|
||||
w2c[:, 3, 3] = 1.0
|
||||
# calculate mvp matrix by proj_mtx @ w2c (mv_mtx)
|
||||
mvp_mtx = proj_mtx @ w2c
|
||||
return mvp_mtx
|
||||
|
||||
|
||||
def get_full_projection_matrix(
|
||||
c2w: Float[Tensor, "B 4 4"], proj_mtx: Float[Tensor, "B 4 4"]
|
||||
) -> Float[Tensor, "B 4 4"]:
|
||||
return (c2w.unsqueeze(0).bmm(proj_mtx.unsqueeze(0))).squeeze(0)
|
||||
|
||||
|
||||
# gaussian splatting functions
|
||||
def convert_pose(C2W):
|
||||
flip_yz = torch.eye(4, device=C2W.device)
|
||||
flip_yz[1, 1] = -1
|
||||
flip_yz[2, 2] = -1
|
||||
C2W = torch.matmul(C2W, flip_yz)
|
||||
return C2W
|
||||
|
||||
|
||||
def get_projection_matrix_gaussian(znear, zfar, fovX, fovY, device="cuda"):
|
||||
tanHalfFovY = math.tan((fovY / 2))
|
||||
tanHalfFovX = math.tan((fovX / 2))
|
||||
|
||||
top = tanHalfFovY * znear
|
||||
bottom = -top
|
||||
right = tanHalfFovX * znear
|
||||
left = -right
|
||||
|
||||
P = torch.zeros(4, 4, device=device)
|
||||
|
||||
z_sign = 1.0
|
||||
|
||||
P[0, 0] = 2.0 * znear / (right - left)
|
||||
P[1, 1] = 2.0 * znear / (top - bottom)
|
||||
P[0, 2] = (right + left) / (right - left)
|
||||
P[1, 2] = (top + bottom) / (top - bottom)
|
||||
P[3, 2] = z_sign
|
||||
P[2, 2] = z_sign * zfar / (zfar - znear)
|
||||
P[2, 3] = -(zfar * znear) / (zfar - znear)
|
||||
return P
|
||||
|
||||
|
||||
def get_fov_gaussian(P):
|
||||
tanHalfFovX = 1 / P[0, 0]
|
||||
tanHalfFovY = 1 / P[1, 1]
|
||||
fovY = math.atan(tanHalfFovY) * 2
|
||||
fovX = math.atan(tanHalfFovX) * 2
|
||||
return fovX, fovY
|
||||
|
||||
|
||||
def get_cam_info_gaussian(c2w, fovx, fovy, znear, zfar):
|
||||
c2w = convert_pose(c2w)
|
||||
world_view_transform = torch.inverse(c2w)
|
||||
|
||||
world_view_transform = world_view_transform.transpose(0, 1).cuda().float()
|
||||
projection_matrix = (
|
||||
get_projection_matrix_gaussian(znear=znear, zfar=zfar, fovX=fovx, fovY=fovy)
|
||||
.transpose(0, 1)
|
||||
.cuda()
|
||||
)
|
||||
full_proj_transform = (
|
||||
world_view_transform.unsqueeze(0).bmm(projection_matrix.unsqueeze(0))
|
||||
).squeeze(0)
|
||||
camera_center = world_view_transform.inverse()[3, :3]
|
||||
|
||||
return world_view_transform, full_proj_transform, camera_center
|
||||
|
||||
|
||||
def binary_cross_entropy(input, target):
|
||||
"""
|
||||
F.binary_cross_entropy is not numerically stable in mixed-precision training.
|
||||
"""
|
||||
return -(target * torch.log(input) + (1 - target) * torch.log(1 - input)).mean()
|
||||
|
||||
|
||||
def tet_sdf_diff(
|
||||
vert_sdf: Float[Tensor, "Nv 1"], tet_edges: Integer[Tensor, "Ne 2"]
|
||||
) -> Float[Tensor, ""]:
|
||||
sdf_f1x6x2 = vert_sdf[:, 0][tet_edges.reshape(-1)].reshape(-1, 2)
|
||||
mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
|
||||
sdf_f1x6x2 = sdf_f1x6x2[mask]
|
||||
sdf_diff = F.binary_cross_entropy_with_logits(
|
||||
sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()
|
||||
) + F.binary_cross_entropy_with_logits(
|
||||
sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float()
|
||||
)
|
||||
return sdf_diff
|
||||
|
||||
|
||||
# Implementation from Latent-NeRF
|
||||
# https://github.com/eladrich/latent-nerf/blob/f49ecefcd48972e69a28e3116fe95edf0fac4dc8/src/latent_nerf/models/mesh_utils.py
|
||||
class MeshOBJ:
|
||||
dx = torch.zeros(3).float()
|
||||
dx[0] = 1
|
||||
dy, dz = dx[[1, 0, 2]], dx[[2, 1, 0]]
|
||||
dx, dy, dz = dx[None, :], dy[None, :], dz[None, :]
|
||||
|
||||
def __init__(self, v: np.ndarray, f: np.ndarray):
|
||||
self.v = v
|
||||
self.f = f
|
||||
self.dx, self.dy, self.dz = MeshOBJ.dx, MeshOBJ.dy, MeshOBJ.dz
|
||||
self.v_tensor = torch.from_numpy(self.v)
|
||||
|
||||
vf = self.v[self.f, :]
|
||||
self.f_center = vf.mean(axis=1)
|
||||
self.f_center_tensor = torch.from_numpy(self.f_center).float()
|
||||
|
||||
e1 = vf[:, 1, :] - vf[:, 0, :]
|
||||
e2 = vf[:, 2, :] - vf[:, 0, :]
|
||||
self.face_normals = np.cross(e1, e2)
|
||||
self.face_normals = (
|
||||
self.face_normals / np.linalg.norm(self.face_normals, axis=-1)[:, None]
|
||||
)
|
||||
self.face_normals_tensor = torch.from_numpy(self.face_normals)
|
||||
|
||||
def normalize_mesh(self, target_scale=0.5):
|
||||
verts = self.v
|
||||
|
||||
# Compute center of bounding box
|
||||
# center = torch.mean(torch.column_stack([torch.max(verts, dim=0)[0], torch.min(verts, dim=0)[0]]))
|
||||
center = verts.mean(axis=0)
|
||||
verts = verts - center
|
||||
scale = np.max(np.linalg.norm(verts, axis=1))
|
||||
verts = (verts / scale) * target_scale
|
||||
|
||||
return MeshOBJ(verts, self.f)
|
||||
|
||||
def winding_number(self, query: torch.Tensor):
|
||||
device = query.device
|
||||
shp = query.shape
|
||||
query_np = query.detach().cpu().reshape(-1, 3).numpy()
|
||||
target_alphas = fast_winding_number_for_meshes(
|
||||
self.v.astype(np.float32), self.f, query_np
|
||||
)
|
||||
return torch.from_numpy(target_alphas).reshape(shp[:-1]).to(device)
|
||||
|
||||
def gaussian_weighted_distance(self, query: torch.Tensor, sigma):
|
||||
device = query.device
|
||||
shp = query.shape
|
||||
query_np = query.detach().cpu().reshape(-1, 3).numpy()
|
||||
distances, _, _ = point_mesh_squared_distance(
|
||||
query_np, self.v.astype(np.float32), self.f
|
||||
)
|
||||
distances = torch.from_numpy(distances).reshape(shp[:-1]).to(device)
|
||||
weight = torch.exp(-(distances / (2 * sigma**2)))
|
||||
return weight
|
||||
|
||||
|
||||
def ce_pq_loss(p, q, weight=None):
|
||||
def clamp(v, T=0.0001):
|
||||
return v.clamp(T, 1 - T)
|
||||
|
||||
p = p.view(q.shape)
|
||||
ce = -1 * (p * torch.log(clamp(q)) + (1 - p) * torch.log(clamp(1 - q)))
|
||||
if weight is not None:
|
||||
ce *= weight
|
||||
return ce.sum()
|
||||
|
||||
|
||||
class ShapeLoss(nn.Module):
|
||||
def __init__(self, guide_shape):
|
||||
super().__init__()
|
||||
self.mesh_scale = 0.7
|
||||
self.proximal_surface = 0.3
|
||||
self.delta = 0.2
|
||||
self.shape_path = guide_shape
|
||||
v, _, _, f, _, _ = read_obj(self.shape_path, float)
|
||||
mesh = MeshOBJ(v, f)
|
||||
matrix_rot = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]]) @ np.array(
|
||||
[[0, 0, 1], [0, 1, 0], [-1, 0, 0]]
|
||||
)
|
||||
self.sketchshape = mesh.normalize_mesh(self.mesh_scale)
|
||||
self.sketchshape = MeshOBJ(
|
||||
np.ascontiguousarray(
|
||||
(matrix_rot @ self.sketchshape.v.transpose(1, 0)).transpose(1, 0)
|
||||
),
|
||||
f,
|
||||
)
|
||||
|
||||
def forward(self, xyzs, sigmas):
|
||||
mesh_occ = self.sketchshape.winding_number(xyzs)
|
||||
if self.proximal_surface > 0:
|
||||
weight = 1 - self.sketchshape.gaussian_weighted_distance(
|
||||
xyzs, self.proximal_surface
|
||||
)
|
||||
else:
|
||||
weight = None
|
||||
indicator = (mesh_occ > 0.5).float()
|
||||
nerf_occ = 1 - torch.exp(-self.delta * sigmas)
|
||||
nerf_occ = nerf_occ.clamp(min=0, max=1.1)
|
||||
loss = ce_pq_loss(
|
||||
nerf_occ, indicator, weight=weight
|
||||
) # order is important for CE loss + second argument may not be optimized
|
||||
return loss
|
||||
|
||||
|
||||
def shifted_expotional_decay(a, b, c, r):
|
||||
return a * torch.exp(-b * r) + c
|
||||
|
||||
|
||||
def shifted_cosine_decay(a, b, c, r):
|
||||
return a * torch.cos(b * r + c) + a
|
||||
|
||||
|
||||
def perpendicular_component(x: Float[Tensor, "B C H W"], y: Float[Tensor, "B C H W"]):
|
||||
# get the component of x that is perpendicular to y
|
||||
eps = torch.ones_like(x[:, 0, 0, 0]) * 1e-6
|
||||
return (
|
||||
x
|
||||
- (
|
||||
torch.mul(x, y).sum(dim=[1, 2, 3])
|
||||
/ torch.maximum(torch.mul(y, y).sum(dim=[1, 2, 3]), eps)
|
||||
).view(-1, 1, 1, 1)
|
||||
* y
|
||||
)
|
||||
|
||||
|
||||
def validate_empty_rays(ray_indices, t_start, t_end):
|
||||
if ray_indices.nelement() == 0:
|
||||
print("Warn Empty rays_indices!")
|
||||
ray_indices = torch.LongTensor([0]).to(ray_indices)
|
||||
t_start = torch.Tensor([0]).to(ray_indices)
|
||||
t_end = torch.Tensor([0]).to(ray_indices)
|
||||
return ray_indices, t_start, t_end
|
||||
|
38
svrm/ldm/utils/typing.py
Normal file
@ -0,0 +1,38 @@
|
||||
"""
|
||||
This module contains type annotations for the project, using
|
||||
1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
|
||||
2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
|
||||
|
||||
Two types of typing checking can be used:
|
||||
1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
|
||||
2. Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking)
|
||||
"""
|
||||
|
||||
# Basic types
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
NewType,
|
||||
Optional,
|
||||
Sized,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
# Tensor dtype
|
||||
# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
|
||||
from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
|
||||
|
||||
# PyTorch Tensor type
|
||||
from torch import Tensor
|
||||
|
||||
# Runtime type checking decorator
|
||||
from typeguard import typechecked as typechecker
|
||||
|
91
svrm/ldm/vis_util.py
Normal file
@ -0,0 +1,91 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from PIL import Image
|
||||
import imageio
|
||||
import time
|
||||
import torch
|
||||
from pytorch3d.io import load_objs_as_meshes, load_obj, save_obj
|
||||
from pytorch3d.ops import interpolate_face_attributes
|
||||
from pytorch3d.common.datatypes import Device
|
||||
from pytorch3d.structures import Meshes
|
||||
from pytorch3d.vis.texture_vis import texturesuv_image_matplotlib
|
||||
from pytorch3d.renderer import (
|
||||
look_at_view_transform,
|
||||
FoVPerspectiveCameras,
|
||||
PointLights,
|
||||
DirectionalLights,
|
||||
AmbientLights,
|
||||
Materials,
|
||||
RasterizationSettings,
|
||||
MeshRenderer,
|
||||
MeshRasterizer,
|
||||
SoftPhongShader,
|
||||
TexturesUV,
|
||||
TexturesVertex,
|
||||
camera_position_from_spherical_angles,
|
||||
BlendParams,
|
||||
)
|
||||
|
||||
|
||||
def render(
|
||||
obj_filename,
|
||||
elev=0,
|
||||
azim=0,
|
||||
resolution=512,
|
||||
gif_dst_path='',
|
||||
n_views=120,
|
||||
fps=30,
|
||||
device="cuda:0",
|
||||
rgb=False
|
||||
):
|
||||
'''
|
||||
obj_filename: path to obj file
|
||||
gif_dst_path:
|
||||
if set a path, will render n_views frames, then save it to a gif file
|
||||
if not set, will render single frame, then return PIL.Image instance
|
||||
rgb: if set true, will convert result to rgb image/frame
|
||||
'''
|
||||
# load mesh
|
||||
mesh = load_objs_as_meshes([obj_filename], device=device)
|
||||
meshes = mesh.extend(n_views)
|
||||
|
||||
if gif_dst_path != '':
|
||||
elev = torch.linspace(elev, elev, n_views+1)[:-1]
|
||||
azim = torch.linspace(0, 360, n_views+1)[:-1]
|
||||
|
||||
# prepare R,T then compute cameras
|
||||
R, T = look_at_view_transform(dist=1.5, elev=elev, azim=azim)
|
||||
cameras = FoVPerspectiveCameras(device=device, R=R, T=T, fov=49.1)
|
||||
|
||||
# init pytorch3d renderer instance
|
||||
renderer = MeshRenderer(
|
||||
rasterizer=MeshRasterizer(
|
||||
cameras=cameras,
|
||||
raster_settings=RasterizationSettings(
|
||||
image_size=resolution,
|
||||
blur_radius=0.0,
|
||||
faces_per_pixel=1,
|
||||
),
|
||||
),
|
||||
shader=SoftPhongShader(
|
||||
device=device,
|
||||
cameras=cameras,
|
||||
lights=AmbientLights(device=device),
|
||||
blend_params=BlendParams(background_color=(1.0, 1.0, 1.0)),
|
||||
)
|
||||
)
|
||||
images = renderer(meshes)
|
||||
|
||||
# single frame rendering
|
||||
if gif_dst_path == '':
|
||||
frame = images[0, ..., :3] if rgb else images[0, ...]
|
||||
frame = Image.fromarray((frame.cpu().squeeze(0) * 255).numpy().astype("uint8"))
|
||||
return frame
|
||||
|
||||
# orbit frames rendering
|
||||
with imageio.get_writer(uri=gif_dst_path, mode='I', duration=1. / fps * 1000, loop=0) as writer:
|
||||
for i in range(n_views):
|
||||
frame = images[i, ..., :3] if rgb else images[i, ...]
|
||||
frame = Image.fromarray((frame.cpu().squeeze(0) * 255).numpy().astype("uint8"))
|
||||
writer.append_data(frame)
|
||||
return gif_dst_path
|
150
svrm/predictor.py
Normal file
@ -0,0 +1,150 @@
|
||||
# Open Source Model Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
||||
# The below Model in this distribution may have been modified by THL A29 Limited ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2024 THL A29 Limited.
|
||||
|
||||
# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
|
||||
# The below software and/or models in this distribution may have been
|
||||
# modified by THL A29 Limited ("Tencent Modifications").
|
||||
# All Tencent Modifications are Copyright (C) THL A29 Limited.
|
||||
|
||||
# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
|
||||
# except for the third-party components listed below.
|
||||
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
|
||||
# in the repsective licenses of these third-party components.
|
||||
# Users must comply with all terms and conditions of original licenses of these third-party
|
||||
# components and must ensure that the usage of the third party components adheres to
|
||||
# all relevant laws and regulations.
|
||||
|
||||
# For avoidance of doubts, Hunyuan 3D means the large language models and
|
||||
# their software and algorithms, including trained model weights, parameters (including
|
||||
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
|
||||
# fine-tuning enabling code and other elements of the foregoing made publicly available
|
||||
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
||||
|
||||
import os
|
||||
import math
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from PIL import Image, ImageSequence
|
||||
from omegaconf import OmegaConf
|
||||
from torchvision import transforms
|
||||
from safetensors.torch import save_file, load_file
|
||||
from .ldm.util import instantiate_from_config
|
||||
from .ldm.vis_util import render
|
||||
|
||||
class MV23DPredictor(object):
|
||||
def __init__(self, ckpt_path, cfg_path, elevation=15, number_view=60,
|
||||
render_size=256, device="cuda:0") -> None:
|
||||
self.device = device
|
||||
self.elevation = elevation
|
||||
self.number_view = number_view
|
||||
self.render_size = render_size
|
||||
|
||||
self.elevation_list = [0, 0, 0, 0, 0, 0, 0]
|
||||
self.azimuth_list = [0, 60, 120, 180, 240, 300, 0]
|
||||
|
||||
st = time.time()
|
||||
self.model = self.init_model(ckpt_path, cfg_path)
|
||||
print(f"=====> mv23d model init time: {time.time() - st}")
|
||||
|
||||
self.input_view_transform = transforms.Compose([
|
||||
transforms.Resize(504, interpolation=Image.BICUBIC),
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
self.final_input_view_transform = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
|
||||
|
||||
def init_model(self, ckpt_path, cfg_path):
|
||||
config = OmegaConf.load(cfg_path)
|
||||
model = instantiate_from_config(config.model)
|
||||
|
||||
weights = load_file("./weights/svrm/svrm.safetensors")
|
||||
model.load_state_dict(weights)
|
||||
|
||||
model.to(self.device)
|
||||
model = model.eval()
|
||||
model.render.half()
|
||||
print(f'Load model successfully')
|
||||
return model
|
||||
|
||||
def create_camera_to_world_matrix(self, elevation, azimuth, cam_dis=1.5):
|
||||
# elevation azimuth are radians
|
||||
# Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere
|
||||
x = np.cos(elevation) * np.cos(azimuth)
|
||||
y = np.cos(elevation) * np.sin(azimuth)
|
||||
z = np.sin(elevation)
|
||||
|
||||
# Calculate camera position, target, and up vectors
|
||||
camera_pos = np.array([x, y, z]) * cam_dis
|
||||
target = np.array([0, 0, 0])
|
||||
up = np.array([0, 0, 1])
|
||||
|
||||
# Construct view matrix
|
||||
forward = target - camera_pos
|
||||
forward /= np.linalg.norm(forward)
|
||||
right = np.cross(forward, up)
|
||||
right /= np.linalg.norm(right)
|
||||
new_up = np.cross(right, forward)
|
||||
new_up /= np.linalg.norm(new_up)
|
||||
cam2world = np.eye(4)
|
||||
cam2world[:3, :3] = np.array([right, new_up, -forward]).T
|
||||
cam2world[:3, 3] = camera_pos
|
||||
return cam2world
|
||||
|
||||
def refine_mask(self, mask, k=16):
|
||||
mask /= 255.0
|
||||
boder_mask = (mask >= -math.pi / 2.0 / k + 0.5) & (mask <= math.pi / 2.0 / k + 0.5)
|
||||
mask[boder_mask] = 0.5 * np.sin(k * (mask[boder_mask] - 0.5)) + 0.5
|
||||
mask[mask < -math.pi / 2.0 / k + 0.5] = 0.0
|
||||
mask[mask > math.pi / 2.0 / k + 0.5] = 1.0
|
||||
return (mask * 255.0).astype(np.uint8)
|
||||
|
||||
def load_images_and_cameras(self, input_imgs, elevation_list, azimuth_list):
|
||||
input_image_list = []
|
||||
input_cam_list = []
|
||||
for input_view_image, elevation, azimuth in zip(input_imgs, elevation_list, azimuth_list):
|
||||
input_view_image = self.input_view_transform(input_view_image)
|
||||
input_image_list.append(input_view_image)
|
||||
|
||||
input_view_cam_pos = self.create_camera_to_world_matrix(np.radians(elevation), np.radians(azimuth))
|
||||
input_view_cam_intrinsic = np.array([35. / 32, 35. /32, 0.5, 0.5])
|
||||
input_view_cam = torch.from_numpy(
|
||||
np.concatenate([input_view_cam_pos.reshape(-1), input_view_cam_intrinsic], 0)
|
||||
).float()
|
||||
input_cam_list.append(input_view_cam)
|
||||
|
||||
pixels_input = torch.stack(input_image_list, dim=0)
|
||||
input_images = self.final_input_view_transform(pixels_input)
|
||||
input_cams = torch.stack(input_cam_list, dim=0)
|
||||
return input_images, input_cams
|
||||
|
||||
def load_data(self, intput_imgs):
|
||||
assert (6+1) == len(intput_imgs)
|
||||
|
||||
input_images, input_cams = self.load_images_and_cameras(intput_imgs, self.elevation_list, self.azimuth_list)
|
||||
input_cams[-1, :] = 0 # for user input view
|
||||
|
||||
data = {}
|
||||
data["input_view"] = input_images.unsqueeze(0).to(self.device) # 1 4 3 512 512
|
||||
data["input_view_cam"] = input_cams.unsqueeze(0).to(self.device) # 1 4 20
|
||||
return data
|
||||
|
||||
@torch.no_grad()
|
||||
def predict(
|
||||
self,
|
||||
intput_imgs,
|
||||
save_dir = "outputs/",
|
||||
image_input = None,
|
||||
target_face_count = 10000,
|
||||
do_texture_mapping = True,
|
||||
):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
print(save_dir)
|
||||
|
||||
with torch.cuda.amp.autocast():
|
||||
self.model.export_mesh_with_uv(
|
||||
data = self.load_data(intput_imgs),
|
||||
out_dir = save_dir,
|
||||
target_face_count = target_face_count,
|
||||
do_texture_mapping = do_texture_mapping
|
||||
)
|
90
svrm/utils/camera_utils.py
Normal file
@ -0,0 +1,90 @@
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
def compute_extrinsic_matrix(elevation, azimuth, camera_distance):
|
||||
# 将角度转换为弧度
|
||||
elevation_rad = np.radians(elevation)
|
||||
azimuth_rad = np.radians(azimuth)
|
||||
|
||||
R = np.array([
|
||||
[np.cos(azimuth_rad), 0, -np.sin(azimuth_rad)],
|
||||
[0, 1, 0],
|
||||
[np.sin(azimuth_rad), 0, np.cos(azimuth_rad)],
|
||||
], dtype=np.float32)
|
||||
|
||||
R = R @ np.array([
|
||||
[1, 0, 0],
|
||||
[0, np.cos(elevation_rad), -np.sin(elevation_rad)],
|
||||
[0, np.sin(elevation_rad), np.cos(elevation_rad)]
|
||||
], dtype=np.float32)
|
||||
|
||||
# 构建平移矩阵 T (3x1)
|
||||
T = np.array([[camera_distance], [0], [0]], dtype=np.float32)
|
||||
T = R @ T
|
||||
|
||||
# 组合成 4x4 的变换矩阵
|
||||
extrinsic_matrix = np.vstack((np.hstack((R, T)), np.array([[0, 0, 0, 1]], dtype=np.float32)))
|
||||
|
||||
return extrinsic_matrix
|
||||
|
||||
|
||||
def transform_camera_pose(im_pose, ori_pose, new_pose):
|
||||
T = new_pose @ ori_pose.T
|
||||
transformed_poses = []
|
||||
|
||||
for pose in im_pose:
|
||||
transformed_pose = T @ pose
|
||||
transformed_poses.append(transformed_pose)
|
||||
|
||||
return transformed_poses
|
||||
|
||||
def compute_fov(intrinsic_matrix):
|
||||
# 获取内参矩阵中的焦距值
|
||||
fx = intrinsic_matrix[0, 0]
|
||||
fy = intrinsic_matrix[1, 1]
|
||||
|
||||
h, w = intrinsic_matrix[0,2]*2, intrinsic_matrix[1,2]*2
|
||||
|
||||
# 计算水平和垂直方向的FOV值
|
||||
fov_x = 2 * math.atan(w / (2 * fx)) * 180 / math.pi
|
||||
fov_y = 2 * math.atan(h / (2 * fy)) * 180 / math.pi
|
||||
|
||||
return fov_x, fov_y
|
||||
|
||||
|
||||
|
||||
def rotation_matrix_to_quaternion(rotation_matrix):
|
||||
rot = Rotation.from_matrix(rotation_matrix)
|
||||
quaternion = rot.as_quat()
|
||||
return quaternion
|
||||
|
||||
def quaternion_to_rotation_matrix(quaternion):
|
||||
rot = Rotation.from_quat(quaternion)
|
||||
rotation_matrix = rot.as_matrix()
|
||||
return rotation_matrix
|
||||
|
||||
def remap_points(img_size, match, size=512):
|
||||
H, W, _ = img_size
|
||||
|
||||
S = max(W, H)
|
||||
new_W = int(round(W * size / S))
|
||||
new_H = int(round(H * size / S))
|
||||
cx, cy = new_W // 2, new_H // 2
|
||||
|
||||
# 计算变换后的图像中心点坐标
|
||||
halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
|
||||
|
||||
dw, dh = cx - halfw, cy - halfh
|
||||
|
||||
# 初始化一个新的数组来存储映射回原始图像的点坐标
|
||||
new_match = np.zeros_like(match)
|
||||
|
||||
# 将变换后的点坐标映射回原始图像
|
||||
new_match[:, 0] = (match[:, 0] + dw) / new_W * W
|
||||
new_match[:, 1] = (match[:, 1] + dh) / new_H * H
|
||||
|
||||
#print(dw,new_W,W,dh,new_H,H)
|
||||
|
||||
return new_match
|
||||
|
||||
|
217
svrm/utils/img_utils.py
Normal file
@ -0,0 +1,217 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from skimage.metrics import hausdorff_distance
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
|
||||
def get_input_imgs_path(input_data_dir):
|
||||
path = {}
|
||||
names = ['000', 'ori_000']
|
||||
for name in names:
|
||||
jpg_path = os.path.join(input_data_dir, f"{name}.jpg")
|
||||
png_path = os.path.join(input_data_dir, f"{name}.png")
|
||||
if os.path.exists(jpg_path):
|
||||
path[name] = jpg_path
|
||||
elif os.path.exists(png_path):
|
||||
path[name] = png_path
|
||||
return path
|
||||
|
||||
|
||||
def rgba_to_rgb(image, bg_color=[255, 255, 255]):
|
||||
if image.shape[-1] == 3: return image
|
||||
|
||||
rgba = image.astype(float)
|
||||
rgb = rgba[:, :, :3].copy()
|
||||
alpha = rgba[:, :, 3] / 255.0
|
||||
|
||||
bg = np.ones((image.shape[0], image.shape[1], 3), dtype=np.float32)
|
||||
bg = bg * np.array(bg_color, dtype=np.float32)
|
||||
|
||||
rgb = rgb * alpha[:, :, np.newaxis] + bg * (1 - alpha[:, :, np.newaxis])
|
||||
rgb = rgb.astype(np.uint8)
|
||||
|
||||
return rgb
|
||||
|
||||
|
||||
def resize_with_aspect_ratio(image1, image2, pad_value=[255, 255, 255]):
|
||||
aspect_ratio1 = float(image1.shape[1]) / float(image1.shape[0])
|
||||
aspect_ratio2 = float(image2.shape[1]) / float(image2.shape[0])
|
||||
|
||||
top_pad, bottom_pad, left_pad, right_pad = 0, 0, 0, 0
|
||||
|
||||
if aspect_ratio1 < aspect_ratio2:
|
||||
new_width = (aspect_ratio2 * image1.shape[0])
|
||||
right_pad = left_pad = int((new_width - image1.shape[1]) / 2)
|
||||
else:
|
||||
new_height = (image1.shape[1] / aspect_ratio2)
|
||||
bottom_pad = top_pad = int((new_height - image1.shape[0]) / 2)
|
||||
|
||||
image1_padded = cv2.copyMakeBorder(
|
||||
image1, top_pad, bottom_pad, left_pad, right_pad, cv2.BORDER_CONSTANT, value=pad_value
|
||||
)
|
||||
return image1_padded
|
||||
|
||||
|
||||
def estimate_img_mask(image):
|
||||
# 转换为灰度图像
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# 使用大津法进行阈值分割
|
||||
# _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
|
||||
# mask_otsu = thresh.astype(bool)
|
||||
# thresh_gray = 240
|
||||
|
||||
# 使用 Canny 边缘检测算法找到边缘
|
||||
edges = cv2.Canny(gray, 20, 50)
|
||||
|
||||
# 使用形态学操作扩展边缘
|
||||
kernel = np.ones((3, 3), np.uint8)
|
||||
edges_dilated = cv2.dilate(edges, kernel, iterations=1)
|
||||
|
||||
contours, _ = cv2.findContours(edges_dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
|
||||
# 创建一个空的 mask
|
||||
mask = np.zeros_like(gray, dtype=np.uint8)
|
||||
|
||||
# 根据轮廓信息填充 mask(使用 thickness=cv2.FILLED 参数)
|
||||
cv2.drawContours(mask, contours, -1, 255, thickness=cv2.FILLED)
|
||||
mask = mask.astype(bool)
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def compute_img_diff(img1, img2, matches1, matches1_from_2, vis=False):
|
||||
scale = 0.125
|
||||
gray_trunc_thres = 25 / 255.0
|
||||
|
||||
# Match
|
||||
if matches1.shape[0] > 0:
|
||||
match_scale = np.max(np.ptp(matches1, axis=-1))
|
||||
match_dists = np.sqrt(np.sum((matches1 - matches1_from_2) ** 2, axis=-1))
|
||||
dist_threshold = match_scale * 0.01
|
||||
match_num = np.sum(match_dists <= dist_threshold)
|
||||
match_rate = np.mean(match_dists <= dist_threshold)
|
||||
else:
|
||||
match_num = 0
|
||||
match_rate = 0
|
||||
|
||||
# IOU
|
||||
img1_mask = estimate_img_mask(img1)
|
||||
img2_mask = estimate_img_mask(img2)
|
||||
img_intersection = (img1_mask == 1) & (img2_mask == 1)
|
||||
img_union = (img1_mask == 1) | (img2_mask == 1)
|
||||
intersection = np.sum(img_intersection == 1)
|
||||
union = np.sum(img_union == 1)
|
||||
mask_iou = intersection / union if union != 0 else 0
|
||||
|
||||
# Gray
|
||||
height, width = img1.shape[:2]
|
||||
img1_gray = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)
|
||||
img2_gray = cv2.cvtColor(img2, cv2.COLOR_BGR2GRAY)
|
||||
img1_gray = cv2.GaussianBlur(img1_gray, (7, 7), 0)
|
||||
img2_gray = cv2.GaussianBlur(img2_gray, (7, 7), 0)
|
||||
|
||||
# Gray Diff
|
||||
img1_gray_small = cv2.resize(img1_gray, (int(width * scale), int(height * scale)),
|
||||
interpolation=cv2.INTER_LINEAR) / 255.0
|
||||
img2_gray_small = cv2.resize(img2_gray, (int(width * scale), int(height * scale)),
|
||||
interpolation=cv2.INTER_LINEAR) / 255.0
|
||||
img_gray_small_diff = np.abs(img1_gray_small - img2_gray_small)
|
||||
gray_diff = img_gray_small_diff.sum() / (union * scale) if union != 0 else 1
|
||||
|
||||
img_gray_small_diff_trunc = img_gray_small_diff.copy()
|
||||
img_gray_small_diff_trunc[img_gray_small_diff < gray_trunc_thres] = 0
|
||||
gray_diff_trunc = img_gray_small_diff_trunc.sum() / (union * scale) if union != 0 else 1
|
||||
|
||||
# Edge
|
||||
img1_edge = cv2.Canny(img1_gray, 100, 200)
|
||||
img2_edge = cv2.Canny(img2_gray, 100, 200)
|
||||
bw_edges1 = (img1_edge > 0).astype(bool)
|
||||
bw_edges2 = (img2_edge > 0).astype(bool)
|
||||
hausdorff_dist = hausdorff_distance(bw_edges1, bw_edges2)
|
||||
if vis == True:
|
||||
fig, axs = plt.subplots(1, 4, figsize=(15, 5))
|
||||
axs[0].imshow(img1_gray, cmap='gray')
|
||||
axs[0].set_title('Img1')
|
||||
axs[1].imshow(img2_gray, cmap='gray')
|
||||
axs[1].set_title('Img2')
|
||||
axs[2].imshow(img1_mask)
|
||||
axs[2].set_title('Mask1')
|
||||
axs[3].imshow(img2_mask)
|
||||
axs[3].set_title('Mask2')
|
||||
plt.show()
|
||||
plt.figure()
|
||||
mask_cmp = np.zeros((height, width, 3))
|
||||
mask_cmp[img_intersection, 1] = 1
|
||||
mask_cmp[img_union, 0] = 1
|
||||
plt.imshow(mask_cmp)
|
||||
plt.show()
|
||||
fig, axs = plt.subplots(1, 4, figsize=(15, 5))
|
||||
axs[0].imshow(img1_gray_small, cmap='gray')
|
||||
axs[0].set_title('Img1 Gray')
|
||||
axs[1].imshow(img2_gray_small, cmap='gray')
|
||||
axs[1].set_title('Img2 Gary')
|
||||
axs[2].imshow(img_gray_small_diff, cmap='gray')
|
||||
axs[2].set_title('diff')
|
||||
axs[3].imshow(img_gray_small_diff_trunc, cmap='gray')
|
||||
axs[3].set_title('diff_trunct')
|
||||
plt.show()
|
||||
fig, axs = plt.subplots(1, 2, figsize=(15, 5))
|
||||
axs[0].imshow(img1_edge, cmap='gray')
|
||||
axs[0].set_title('img1_edge')
|
||||
axs[1].imshow(img2_edge, cmap='gray')
|
||||
axs[1].set_title('img2_edge')
|
||||
plt.show()
|
||||
|
||||
info = {}
|
||||
info['match_num'] = match_num
|
||||
info['match_rate'] = match_rate
|
||||
info['mask_iou'] = mask_iou
|
||||
info['gray_diff'] = gray_diff
|
||||
info['gray_diff_trunc'] = gray_diff_trunc
|
||||
info['hausdorff_dist'] = hausdorff_dist
|
||||
return info
|
||||
|
||||
|
||||
def predict_match_success_human(info):
|
||||
match_num = info['match_num']
|
||||
match_rate = info['match_rate']
|
||||
mask_iou = info['mask_iou']
|
||||
gray_diff = info['gray_diff']
|
||||
gray_diff_trunc = info['gray_diff_trunc']
|
||||
hausdorff_dist = info['hausdorff_dist']
|
||||
|
||||
if mask_iou > 0.95:
|
||||
return True
|
||||
|
||||
if match_num < 20 or match_rate < 0.7:
|
||||
return False
|
||||
|
||||
if mask_iou > 0.80 and gray_diff < 0.040 and gray_diff_trunc < 0.010:
|
||||
return True
|
||||
|
||||
if mask_iou > 0.70 and gray_diff < 0.050 and gray_diff_trunc < 0.008:
|
||||
return True
|
||||
|
||||
'''
|
||||
if match_rate<0.70 or match_num<3000:
|
||||
return False
|
||||
|
||||
if (mask_iou>0.85 and hausdorff_dist<20)or (gray_diff<0.015 and gray_diff_trunc<0.01) or match_rate>=0.90:
|
||||
return True
|
||||
'''
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def predict_match_success(info, model=None):
|
||||
if model == None:
|
||||
return predict_match_success_human(info)
|
||||
else:
|
||||
feat_name = ['match_num', 'match_rate', 'mask_iou', 'gray_diff', 'gray_diff_trunc', 'hausdorff_dist']
|
||||
# 提取特征
|
||||
features = [info[f] for f in feat_name]
|
||||
# 预测
|
||||
pred = model.predict([features])[0]
|
||||
return pred >= 0.5
|
21
svrm/utils/log_utils.py
Normal file
@ -0,0 +1,21 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
def txt_to_img(text, font=cv2.FONT_HERSHEY_SIMPLEX, font_scale=1, font_thickness=2, img_width=1000, img_height=100, text_color=(0, 0, 0), bg_color=(255, 255, 255)):
|
||||
lines = text.split('\n')
|
||||
img_lines = []
|
||||
for line in lines:
|
||||
# 计算每行文本的尺寸
|
||||
line_size, _ = cv2.getTextSize(line, font, font_scale, font_thickness)
|
||||
line_width, line_height = line_size
|
||||
# 创建包含当前行的图像画布
|
||||
img_line = np.full((int(line_height*1.5) , img_width, 3), bg_color, dtype=np.uint8)
|
||||
text_x, text_y = 0, line_height
|
||||
cv2.putText(img_line, line, (text_x, text_y), font, font_scale, text_color, font_thickness, cv2.LINE_AA)
|
||||
|
||||
img_lines.append(img_line)
|
||||
|
||||
|
||||
# 垂直堆叠所有行图像
|
||||
img = np.vstack(img_lines)
|
||||
return img
|
4
weights/download.sh
Normal file
@ -0,0 +1,4 @@
|
||||
python3 hg_download.py Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled
|
||||
|
||||
wget https://github.com/danielgatis/rembg/releases/download/v0.0.0/u2net.onnx
|
||||
ln -s /root/.u2netxxx/u2net.onnx u2net.onnx
|
13
weights/hg_download.py
Normal file
@ -0,0 +1,13 @@
|
||||
# download huggingface pretrain model
|
||||
|
||||
import os
|
||||
import sys
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
if __name__ == "__main__":
|
||||
repo_id = sys.argv[1]
|
||||
snapshot_download(
|
||||
repo_id = repo_id,
|
||||
cache_dir = './',
|
||||
)
|
||||
print("Done")
|