mirror of
https://gitee.com/ascend/ModelLink.git
synced 2024-12-01 19:39:02 +08:00
add adapter for cann fa
This commit is contained in:
parent
3dbf8233cb
commit
14759c2704
595
ascendspeed/te/ops/csrc/aclnn_common.h
Normal file
@ -0,0 +1,595 @@
/******************************************************************************
 * Copyright (c) 2022 Huawei Technologies Co., Ltd
 * All rights reserved.
 *
 * Licensed under the BSD 3-Clause License (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/BSD-3-Clause
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 ******************************************************************************/

#ifndef PYTORCH_NPU_UTIL_HPP_
#define PYTORCH_NPU_UTIL_HPP_

#include <torch_npu/csrc/framework/utils/CalcuOpUtil.h>
#include <torch_npu/csrc/framework/utils/OpAdapter.h>
#include <dlfcn.h>
#include <vector>
#include <functional>
#include <type_traits>
#include <ATen/Tensor.h>
#include <acl/acl_base.h>
#include <acl/acl_rt.h>
#include <c10/util/Exception.h>
#include "torch_npu/csrc/core/npu/NPUStream.h"
#include "torch_npu/csrc/framework/OpCommand.h"
#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
#include "torch_npu/csrc/framework/utils/OpPreparation.h"
#include "torch_npu/csrc/framework/interface/EnvVariables.h"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"

#define NPU_NAME_SPACE at_npu::native

using aclOpExecutor = struct aclOpExecutor;
using aclTensor = struct aclTensor;
using aclScalar = struct aclScalar;
using aclIntArray = struct aclIntArray;
using aclFloatArray = struct aclFloatArray;
using aclBoolArray = struct aclBoolArray;
using aclTensorList = struct aclTensorList;

using _aclCreateTensor = aclTensor *(*)(const int64_t *view_dims, uint64_t view_dims_num, aclDataType data_type,
    const int64_t *stride, int64_t offset, aclFormat format, const int64_t *storage_dims, uint64_t storage_dims_num,
    void *tensor_data);
using _aclCreateScalar = aclScalar *(*)(void *value, aclDataType data_type);
using _aclCreateIntArray = aclIntArray *(*)(const int64_t *value, uint64_t size);
using _aclCreateFloatArray = aclFloatArray *(*)(const float *value, uint64_t size);
using _aclCreateBoolArray = aclBoolArray *(*)(const bool *value, uint64_t size);
using _aclCreateTensorList = aclTensorList *(*)(const aclTensor *const *value, uint64_t size);

using _aclDestroyTensor = int (*)(const aclTensor *tensor);
using _aclDestroyScalar = int (*)(const aclScalar *scalar);
using _aclDestroyIntArray = int (*)(const aclIntArray *array);
using _aclDestroyFloatArray = int (*)(const aclFloatArray *array);
using _aclDestroyBoolArray = int (*)(const aclBoolArray *array);
using _aclDestroyTensorList = int (*)(const aclTensorList *array);

constexpr int kHashBufSize = 8192;
constexpr int kHashBufMaxSize = kHashBufSize + 1024;
extern thread_local char g_hashBuf[kHashBufSize];
extern thread_local int g_hashOffset;

#define AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(_)  \
    _(at::ScalarType::Byte, ACL_UINT8)               \
    _(at::ScalarType::Char, ACL_INT8)                \
    _(at::ScalarType::Short, ACL_INT16)              \
    _(at::ScalarType::Int, ACL_INT32)                \
    _(at::ScalarType::Long, ACL_INT64)               \
    _(at::ScalarType::Half, ACL_FLOAT16)             \
    _(at::ScalarType::Float, ACL_FLOAT)              \
    _(at::ScalarType::Double, ACL_DOUBLE)            \
    _(at::ScalarType::ComplexHalf, ACL_DT_UNDEFINED) \
    _(at::ScalarType::ComplexFloat, ACL_COMPLEX64)   \
    _(at::ScalarType::ComplexDouble, ACL_COMPLEX128) \
    _(at::ScalarType::Bool, ACL_BOOL)                \
    _(at::ScalarType::QInt8, ACL_DT_UNDEFINED)       \
    _(at::ScalarType::QUInt8, ACL_DT_UNDEFINED)      \
    _(at::ScalarType::QInt32, ACL_DT_UNDEFINED)      \
    _(at::ScalarType::BFloat16, ACL_BF16)            \
    _(at::ScalarType::QUInt4x2, ACL_DT_UNDEFINED)    \
    _(at::ScalarType::QUInt2x4, ACL_DT_UNDEFINED)    \
    _(at::ScalarType::Undefined, ACL_DT_UNDEFINED)   \
    _(at::ScalarType::NumOptions, ACL_DT_UNDEFINED)

constexpr aclDataType kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(at::ScalarType::NumOptions) + 1] = {
#define DEFINE_ENUM(_1, n) n,
    AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(DEFINE_ENUM)
#undef DEFINE_ENUM
};

#define GET_OP_API_FUNC(apiName) reinterpret_cast<_##apiName>(GetOpApiFuncAddr(#apiName))

#define MEMCPY_TO_BUF(data_expression, size_expression)                 \
    if (g_hashOffset + (size_expression) > kHashBufSize) {              \
        g_hashOffset = kHashBufMaxSize;                                 \
        return;                                                         \
    }                                                                   \
    memcpy(g_hashBuf + g_hashOffset, data_expression, size_expression); \
    g_hashOffset += size_expression;
inline const char *GetOpApiLibName(void)
{
    return "libopapi.so";
}

inline const char *GetCustOpApiLibName(void)
{
    return "libcust_opapi.so";
}

inline void *GetOpApiFuncAddrInLib(void *handler, const char *libName, const char *apiName)
{
    auto funcAddr = dlsym(handler, apiName);
    if (funcAddr == nullptr) {
        ASCEND_LOGW("dlsym %s from %s failed, error:%s.", apiName, libName, dlerror());
    }
    return funcAddr;
}

inline void *GetOpApiLibHandler(const char *libName)
{
    auto handler = dlopen(libName, RTLD_LAZY);
    if (handler == nullptr) {
        ASCEND_LOGW("dlopen %s failed, error:%s.", libName, dlerror());
    }
    return handler;
}

inline void *GetOpApiFuncAddr(const char *apiName)
{
    static auto custOpApiHandler = GetOpApiLibHandler(GetCustOpApiLibName());
    if (custOpApiHandler != nullptr) {
        auto funcAddr = GetOpApiFuncAddrInLib(custOpApiHandler, GetCustOpApiLibName(), apiName);
        if (funcAddr != nullptr) {
            return funcAddr;
        }
    }

    static auto opApiHandler = GetOpApiLibHandler(GetOpApiLibName());
    if (opApiHandler == nullptr) {
        return nullptr;
    }
    return GetOpApiFuncAddrInLib(opApiHandler, GetOpApiLibName(), apiName);
}

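For orientation, a hedged sketch (not part of the commit) of how the loader helpers and GET_OP_API_FUNC combine. Note the lookup order above: the custom-op library libcust_opapi.so is tried first, then the stock libopapi.so. aclCreateIntArray is a real aclnn entry point; the wrapper function here is purely illustrative.

inline aclIntArray *CreateIntArrayDemo(const std::vector<int64_t> &dims)
{
    // Resolved once per process; nullptr when the library or symbol is absent.
    static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray);
    if (aclCreateIntArray == nullptr) {
        return nullptr;  // callers must tolerate a missing opapi library
    }
    return aclCreateIntArray(dims.data(), dims.size());
}
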
inline c10::Scalar ConvertTensorToScalar(const at::Tensor &tensor)
{
    c10::Scalar expScalar;
    const at::Tensor *aclInput = &tensor;
    if (aclInput->scalar_type() == at::ScalarType::Double) {
        double value = *(double *)aclInput->data_ptr();
        c10::Scalar scalar(value);
        expScalar = scalar;
    } else if (aclInput->scalar_type() == at::ScalarType::Long) {
        int64_t value = *(int64_t *)aclInput->data_ptr();
        c10::Scalar scalar(value);
        expScalar = scalar;
    } else if (aclInput->scalar_type() == at::ScalarType::Float) {
        float value = *(float *)aclInput->data_ptr();
        c10::Scalar scalar(value);
        expScalar = scalar;
    } else if (aclInput->scalar_type() == at::ScalarType::Int) {
        int value = *(int *)aclInput->data_ptr();
        c10::Scalar scalar(value);
        expScalar = scalar;
    } else if (aclInput->scalar_type() == at::ScalarType::Half) {
        c10::Half value = *(c10::Half *)aclInput->data_ptr();
        c10::Scalar scalar(value);
        expScalar = scalar;
    } else if (aclInput->scalar_type() == at::ScalarType::Bool) {
        int8_t value = *(int8_t *)aclInput->data_ptr();
        c10::Scalar scalar(value);
        expScalar = scalar;
    } else if (aclInput->scalar_type() == at::ScalarType::ComplexDouble) {
        c10::complex<double> value = *(c10::complex<double> *)aclInput->data_ptr();
        c10::Scalar scalar(value);
        expScalar = scalar;
    } else if (aclInput->scalar_type() == at::ScalarType::ComplexFloat) {
        c10::complex<float> value = *(c10::complex<float> *)aclInput->data_ptr();
        c10::Scalar scalar(value);
        expScalar = scalar;
    } else if (aclInput->scalar_type() == at::ScalarType::BFloat16) {
        c10::BFloat16 value = *(c10::BFloat16 *)aclInput->data_ptr();
        c10::Scalar scalar(value);
        expScalar = scalar;
    } else {
        NPU_LOGE("unsupported scalar type!");
    }
    return expScalar;
}

inline at::Tensor CopyTensorHostToDevice(const at::Tensor &cpu_tensor)
{
    at::Tensor cpuPinMemTensor = cpu_tensor.pin_memory();
    int deviceIndex = 0;
    return cpuPinMemTensor.to(
        c10::Device(at_npu::key::NativeDeviceType, deviceIndex), cpuPinMemTensor.scalar_type(), true, true);
}

inline at::Tensor CopyScalarToDevice(const c10::Scalar &cpu_scalar, at::ScalarType scalar_data_type)
{
    return CopyTensorHostToDevice(scalar_to_tensor(cpu_scalar).to(scalar_data_type));
}

inline aclTensor *ConvertType(const at::Tensor &at_tensor)
{
    static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
    if (aclCreateTensor == nullptr) {
        return nullptr;
    }

    if (!at_tensor.defined()) {
        return nullptr;
    }
    at::ScalarType scalar_data_type = at_tensor.scalar_type();
    aclDataType acl_data_type = kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(scalar_data_type)];
    TORCH_CHECK(
        acl_data_type != ACL_DT_UNDEFINED, std::string(c10::toString(scalar_data_type)) + " has not been supported")
    c10::SmallVector<int64_t, 5> storageDims;
    // if acl_data_type is ACL_STRING, storageDims is empty.
    auto itemsize = at_tensor.itemsize();
    if (itemsize == 0) {
        AT_ERROR("When ConvertType, the tensor item size cannot be zero.");
        return nullptr;
    }
    if (acl_data_type != ACL_STRING) {
        storageDims.push_back(at_tensor.storage().nbytes() / itemsize);
    }

    const auto dimNum = at_tensor.sizes().size();
    aclFormat format = ACL_FORMAT_ND;
    switch (dimNum) {
        case 3:
            format = ACL_FORMAT_NCL;
            break;
        case 4:
            format = ACL_FORMAT_NCHW;
            break;
        case 5:
            format = ACL_FORMAT_NCDHW;
            break;
        default:
            format = ACL_FORMAT_ND;
    }

    if (at_tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
        c10::Scalar expScalar = ConvertTensorToScalar(at_tensor);
        at::Tensor aclInput = CopyScalarToDevice(expScalar, scalar_data_type);
        return aclCreateTensor(aclInput.sizes().data(),
            aclInput.sizes().size(),
            acl_data_type,
            aclInput.strides().data(),
            aclInput.storage_offset(),
            format,
            storageDims.data(),
            storageDims.size(),
            aclInput.storage().data());
    }

    auto acl_tensor = aclCreateTensor(at_tensor.sizes().data(),
        at_tensor.sizes().size(),
        acl_data_type,
        at_tensor.strides().data(),
        at_tensor.storage_offset(),
        format,
        storageDims.data(),
        storageDims.size(),
        at_tensor.storage().data());
    return acl_tensor;
}

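The rank-to-format mapping above (3-D gives NCL, 4-D gives NCHW, 5-D gives NCDHW, anything else ND) means callers never name a format explicitly. A hedged sketch, not from the commit; the device-options helper is the same one the ACLNN_CMD macro below uses:

inline void ConvertTypeDemo()
{
    at::TensorOptions options = at::TensorOptions(torch_npu::utils::get_npu_device_type());
    at::Tensor t = at::empty({2, 3, 4, 5}, options.dtype(at::kHalf));  // 4-D -> ACL_FORMAT_NCHW
    aclTensor *acl_t = ConvertType(t);  // nullptr if libopapi.so is unavailable
    (void)acl_t;  // must later be handed to Release(aclTensor *), defined below
}
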
inline aclScalar *ConvertType(const at::Scalar &at_scalar)
{
    static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar);
    if (aclCreateScalar == nullptr) {
        return nullptr;
    }

    at::ScalarType scalar_data_type = at_scalar.type();
    aclDataType acl_data_type = kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(scalar_data_type)];
    TORCH_CHECK(
        acl_data_type != ACL_DT_UNDEFINED, std::string(c10::toString(scalar_data_type)) + " has not been supported")
    aclScalar *acl_scalar = nullptr;
    switch (scalar_data_type) {
        case at::ScalarType::Double: {
            double value = at_scalar.toDouble();
            acl_scalar = aclCreateScalar(&value, acl_data_type);
            break;
        }
        case at::ScalarType::Long: {
            int64_t value = at_scalar.toLong();
            acl_scalar = aclCreateScalar(&value, acl_data_type);
            break;
        }
        case at::ScalarType::Bool: {
            bool value = at_scalar.toBool();
            acl_scalar = aclCreateScalar(&value, acl_data_type);
            break;
        }
        case at::ScalarType::ComplexDouble: {
            auto value = at_scalar.toComplexDouble();
            acl_scalar = aclCreateScalar(&value, acl_data_type);
            break;
        }
        default:
            acl_scalar = nullptr;
            break;
    }
    return acl_scalar;
}

inline aclIntArray *ConvertType(const at::IntArrayRef &at_array)
{
    static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray);
    if (aclCreateIntArray == nullptr) {
        return nullptr;
    }
    auto array = aclCreateIntArray(at_array.data(), at_array.size());
    return array;
}

template <std::size_t N>
inline aclBoolArray *ConvertType(const std::array<bool, N> &value)
{
    static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray);
    if (aclCreateBoolArray == nullptr) {
        return nullptr;
    }

    auto array = aclCreateBoolArray(value.data(), value.size());
    return array;
}

inline aclBoolArray *ConvertType(const at::ArrayRef<bool> &value)
{
    static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray);
    if (aclCreateBoolArray == nullptr) {
        return nullptr;
    }

    auto array = aclCreateBoolArray(value.data(), value.size());
    return array;
}

inline aclTensorList *ConvertType(const at::TensorList &at_tensor_list)
{
    static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList);
    if (aclCreateTensorList == nullptr) {
        return nullptr;
    }

    std::vector<const aclTensor *> tensor_list(at_tensor_list.size());
    for (size_t i = 0; i < at_tensor_list.size(); i++) {
        tensor_list[i] = ConvertType(at_tensor_list[i]);
    }
    auto acl_tensor_list = aclCreateTensorList(tensor_list.data(), tensor_list.size());
    return acl_tensor_list;
}

inline aclTensor *ConvertType(const c10::optional<at::Tensor> &opt_tensor)
{
    if (opt_tensor.has_value() && opt_tensor.value().defined()) {
        return ConvertType(opt_tensor.value());
    }
    return nullptr;
}

inline aclIntArray *ConvertType(const c10::optional<at::IntArrayRef> &opt_array)
{
    if (opt_array.has_value()) {
        return ConvertType(opt_array.value());
    }
    return nullptr;
}

inline aclScalar *ConvertType(const c10::optional<at::Scalar> &opt_scalar)
{
    if (opt_scalar.has_value()) {
        return ConvertType(opt_scalar.value());
    }
    return nullptr;
}

inline aclDataType ConvertType(const at::ScalarType scalarType)
{
    return kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(scalarType)];
}

template <typename T>
T ConvertType(T value)
{
    return value;
}

template <typename Tuple, size_t... I>
auto ConvertToOpApiFunc(const Tuple &params, void *opApiAddr, std::index_sequence<I...>)
{
    using OpApiFunc = int (*)(typename std::decay<decltype(std::get<I>(params))>::type...);
    auto func = reinterpret_cast<OpApiFunc>(opApiAddr);
    return func;
}

template <typename Tuple>
auto ConvertToOpApiFunc(const Tuple &params, void *opApiAddr)
{
    static constexpr auto size = std::tuple_size<Tuple>::value;
    return ConvertToOpApiFunc(params, opApiAddr, std::make_index_sequence<size>{});
}

inline void Release(aclTensor *p)
{
    static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor);
    if (aclDestroyTensor == nullptr) {
        return;
    }
    aclDestroyTensor(p);
}

inline void Release(aclScalar *p)
{
    static const auto aclDestroyScalar = GET_OP_API_FUNC(aclDestroyScalar);
    if (aclDestroyScalar == nullptr) {
        return;
    }
    aclDestroyScalar(p);
}

inline void Release(aclIntArray *p)
{
    static const auto aclDestroyIntArray = GET_OP_API_FUNC(aclDestroyIntArray);
    if (aclDestroyIntArray == nullptr) {
        return;
    }

    aclDestroyIntArray(p);
}

inline void Release(aclBoolArray *p)
{
    static const auto aclDestroyBoolArray = GET_OP_API_FUNC(aclDestroyBoolArray);
    if (aclDestroyBoolArray == nullptr) {
        return;
    }

    aclDestroyBoolArray(p);
}

inline void Release(aclTensorList *p)
{
    static const auto aclDestroyTensorList = GET_OP_API_FUNC(aclDestroyTensorList);
    if (aclDestroyTensorList == nullptr) {
        return;
    }

    aclDestroyTensorList(p);
}

template <typename T>
void Release(T value)
{
    (void)value;
}

template <typename Tuple, size_t... I>
void CallRelease(Tuple t, std::index_sequence<I...>)
{
    (void)std::initializer_list<int>{(Release(std::get<I>(t)), 0)...};
}

template <typename Tuple>
void ReleaseConvertTypes(Tuple &t)
{
    static constexpr auto size = std::tuple_size<Tuple>::value;
    CallRelease(t, std::make_index_sequence<size>{});
}

template <typename... Ts>
constexpr auto ConvertTypes(Ts &...args)
{
    return std::make_tuple(ConvertType(args)...);
}

template <typename Function, typename Tuple, size_t... I>
auto call(Function f, Tuple t, std::index_sequence<I...>)
{
    return f(std::get<I>(t)...);
}

template <typename Function, typename Tuple>
auto call(Function f, Tuple t)
{
    static constexpr auto size = std::tuple_size<Tuple>::value;
    return call(f, t, std::make_index_sequence<size>{});
}

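Taken together, ConvertTypes, ConvertToOpApiFunc, call, and ReleaseConvertTypes let one macro drive any aclnn entry point: arguments become a tuple of ACL handles, the tuple's element types synthesize the C function-pointer signature, the call is made, and the handles are destroyed. A condensed, hedged sketch of that life cycle; "aclnnSomeOpGetWorkspaceSize" is a placeholder name, not a real entry point:

inline void AclnnCallDemo(const at::Tensor &query, const at::Tensor &key)
{
    double scale = 1.0;
    uint64_t workspace_size = 0;
    uint64_t *workspace_size_addr = &workspace_size;
    aclOpExecutor *executor = nullptr;
    aclOpExecutor **executor_addr = &executor;
    // Convert every argument into the handle type the C API expects.
    auto params = ConvertTypes(query, key, scale, workspace_size_addr, executor_addr);
    auto func = ConvertToOpApiFunc(params, GetOpApiFuncAddr("aclnnSomeOpGetWorkspaceSize"));
    if (func != nullptr) {
        (void)call(func, params);  // expands the tuple into the C function call
    }
    ReleaseConvertTypes(params);   // destroys the aclTensor/aclScalar handles
}
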
template <std::size_t N>
void AddParamToBuf(const std::array<bool, N> &value)
{
    MEMCPY_TO_BUF(value.data(), value.size() * sizeof(bool));
}

template <typename T>
void AddParamToBuf(const T &value)
{
    MEMCPY_TO_BUF(&value, sizeof(T));
}

void AddParamToBuf(const at::Tensor &);
void AddParamToBuf(const at::Scalar &);
void AddParamToBuf(const at::IntArrayRef &);
void AddParamToBuf(const at::ArrayRef<bool> &);
void AddParamToBuf(const at::TensorList &);
void AddParamToBuf(const c10::optional<at::Tensor> &);
void AddParamToBuf(const c10::optional<at::IntArrayRef> &);
void AddParamToBuf(const c10::optional<at::Scalar> &);
void AddParamToBuf(const at::ScalarType);
void AddParamToBuf(const string &);
void AddParamToBuf();

template <typename T, typename... Args>
void AddParamToBuf(const T &arg, Args &...args)
{
    AddParamToBuf(arg);
    AddParamToBuf(args...);
}

uint64_t CalcHashId();
using InitHugeMemThreadLocal = int (*)(void *, bool);
using UnInitHugeMemThreadLocal = void (*)(void *, bool);
using ReleaseHugeMem = void (*)(void *, bool);

#define ACLNN_CMD(aclnn_api, ...)                                                                           \
    do {                                                                                                    \
        static const auto getWorkspaceSizeFuncAddr = GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize");       \
        static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api);                                     \
        static const auto initMemAddr = GetOpApiFuncAddr("InitHugeMemThreadLocal");                         \
        static const auto unInitMemAddr = GetOpApiFuncAddr("UnInitHugeMemThreadLocal");                     \
        static const auto releaseMemAddr = GetOpApiFuncAddr("ReleaseHugeMem");                              \
        TORCH_CHECK(getWorkspaceSizeFuncAddr != nullptr && opApiFuncAddr != nullptr,                        \
            #aclnn_api,                                                                                     \
            " or ",                                                                                         \
            #aclnn_api "GetWorkspaceSize",                                                                  \
            " not in ",                                                                                     \
            GetOpApiLibName(),                                                                              \
            ", or ",                                                                                        \
            GetOpApiLibName(),                                                                              \
            " not found.");                                                                                 \
        auto acl_stream = c10_npu::getCurrentNPUStream().stream(false);                                     \
        uint64_t workspace_size = 0;                                                                        \
        uint64_t *workspace_size_addr = &workspace_size;                                                    \
        aclOpExecutor *executor = nullptr;                                                                  \
        aclOpExecutor **executor_addr = &executor;                                                          \
        InitHugeMemThreadLocal initMemFunc = reinterpret_cast<InitHugeMemThreadLocal>(initMemAddr);         \
        UnInitHugeMemThreadLocal unInitMemFunc = reinterpret_cast<UnInitHugeMemThreadLocal>(unInitMemAddr); \
        if (initMemFunc) {                                                                                  \
            initMemFunc(nullptr, false);                                                                    \
        }                                                                                                   \
        auto converted_params = ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr);              \
        static auto getWorkspaceSizeFunc = ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr);  \
        auto workspace_status = call(getWorkspaceSizeFunc, converted_params);                               \
        TORCH_CHECK(workspace_status == 0, "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg());    \
        void *workspace_addr = nullptr;                                                                     \
        if (workspace_size != 0) {                                                                          \
            at::TensorOptions options = at::TensorOptions(torch_npu::utils::get_npu_device_type());         \
            auto workspace_tensor = at::empty({workspace_size}, options.dtype(at::kByte));                  \
            workspace_addr = workspace_tensor.storage().data();                                             \
        }                                                                                                   \
        auto acl_call = [converted_params, workspace_addr, workspace_size, acl_stream, executor]() -> int { \
            typedef int (*OpApiFunc)(void *, uint64_t, aclOpExecutor *, const aclrtStream);                 \
            OpApiFunc opApiFunc = reinterpret_cast<OpApiFunc>(opApiFuncAddr);                               \
            auto api_ret = opApiFunc(workspace_addr, workspace_size, executor, acl_stream);                 \
            TORCH_CHECK(api_ret == 0, "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg());         \
            ReleaseConvertTypes(converted_params);                                                          \
            ReleaseHugeMem releaseMemFunc = reinterpret_cast<ReleaseHugeMem>(releaseMemAddr);               \
            if (releaseMemFunc) {                                                                           \
                releaseMemFunc(nullptr, false);                                                             \
            }                                                                                               \
            return api_ret;                                                                                 \
        };                                                                                                  \
        at_npu::native::OpCommand cmd;                                                                      \
        cmd.Name(#aclnn_api);                                                                               \
        cmd.SetCustomHandler(acl_call);                                                                     \
        cmd.Run();                                                                                          \
        if (unInitMemFunc) {                                                                                \
            unInitMemFunc(nullptr, false);                                                                  \
        }                                                                                                   \
    } while (false)
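With this macro in place, a whole aclnn adapter collapses to a single statement: the macro resolves <op>GetWorkspaceSize and <op> from libopapi.so, sizes and allocates the workspace as an NPU byte tensor, and enqueues the call through OpCommand with a custom handler. The invocation below is taken from the flash-attention adapter later in this commit:

ACLNN_CMD(aclnnFlashAttentionScore, format_query, format_key, format_value,
    format_pse, format_drop_mask, format_padding_mask, format_atten_mask, prefixN,
    scale, keep_prob, pre_tockens, next_tockens, head_num, input_layout_ptr, inner_precise,
    sparse_mode, softmax_max, softmax_sum, softmax_out, attention_score);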

#endif // PYTORCH_NPU_UTIL_HPP_
@ -1,46 +0,0 @@
// Copyright (c) 2023 Huawei Technologies Co., Ltd
// Copyright (c) 2019, Facebook CORPORATION.
// All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <string>
#include <vector>

#include <torch/csrc/autograd/custom_function.h>
#include "torch_npu/csrc/core/npu/SecondaryStreamGuard.h"
#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/aten/ops/op_api/op_api_common.h"
#include "torch_npu/csrc/aten/NPUGeneratorImpl.h"
#include "torch_npu/csrc/framework/OpCommand.h"
#include "torch_npu/csrc/framework/utils/OpPreparation.h"

#include "atb/operation_torch.h"
#include "common.h"

using namespace std;
torch::Tensor add(const torch::Tensor &input1, const torch::Tensor &input2)
{
    OperationTorch atb("AddOperation");

    std::vector<torch::Tensor> atInTensors;
    std::vector<torch::Tensor> atOutTensors;
    atInTensors.push_back(input1);
    atInTensors.push_back(input2);
    atb.SetParam("{\"scale\": 1}");
    atOutTensors = atb.Execute(atInTensors);

    return atOutTensors[0];
}
@ -16,27 +16,19 @@
#ifndef ASCENDSPEED_TE_OPS_CSRC_COMMON_H
#define ASCENDSPEED_TE_OPS_CSRC_COMMON_H

#include <torch/csrc/autograd/custom_function.h>

#include "torch_npu/csrc/core/npu/SecondaryStreamGuard.h"
#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/aten/ops/op_api/op_api_common.h"
#include "torch_npu/csrc/aten/NPUGeneratorImpl.h"
#include "torch_npu/csrc/framework/OpCommand.h"
#include "torch_npu/csrc/framework/utils/OpPreparation.h"

#include <torch_npu/csrc/core/npu/NPUStream.h>
#include <torch_npu/csrc/core/npu/DeviceUtils.h>
#include <torch_npu/csrc/framework/OpCommand.h>
#include <torch/script.h>
#include <torch/custom_class.h>

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, int64_t, int64_t, int64_t> npu_flash_attention(
    const at::Tensor &query, const at::Tensor &key,
-    const at::Tensor &value, int64_t head_num, c10::string_view input_layout,
+    const at::Tensor &value, int64_t head_num, std::string input_layout,
    const c10::optional<at::Tensor> &pse_opt, const c10::optional<at::Tensor> &padding_mask_opt,
    const c10::optional<at::Tensor> &atten_mask_opt,
    double scale, double keep_prob, int64_t pre_tockens, int64_t next_tockens, int64_t inner_precise,
-    bool gen_mask_parallel, bool sync);
+    c10::optional<at::IntArrayRef> prefix_opt, int64_t sparse_mode, bool gen_mask_parallel, bool sync);

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_flash_attention_grad(
    const at::Tensor &query,
@ -44,7 +36,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_flash_attention_g
    const at::Tensor &value,
    const at::Tensor &dy,
    int64_t head_num,
-    c10::string_view input_layout,
+    std::string input_layout,
    const c10::optional<at::Tensor> &pse,
    const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
@ -60,9 +52,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_flash_attention_g
    int64_t seed,
    int64_t offset,
    int64_t numels,
+    c10::optional<at::IntArrayRef> prefix,
+    int64_t sparse_mode,
    bool gen_mask_parallel,
    bool sync);

torch::Tensor add(const torch::Tensor &input1, const torch::Tensor &input2);

#endif

@ -14,22 +14,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.


#include <torch/extension.h>
#include <torch/csrc/autograd/custom_function.h>
#include "torch_npu/csrc/core/npu/SecondaryStreamGuard.h"
#include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
#include "torch_npu/csrc/framework/utils/RandomOpAdapter.h"
#include "torch_npu/csrc/framework/utils/OpAdapter.h"
#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
#include "torch_npu/csrc/aten/ops/op_api/op_api_common.h"
#include "torch_npu/csrc/aten/NPUGeneratorImpl.h"
#include "torch_npu/csrc/framework/OpCommand.h"
#include "torch_npu/csrc/framework/utils/OpPreparation.h"
#include "torch_npu/csrc/aten/CustomFunctions.h"
#include "common.h"
#include "aclnn_common.h"

const static int FLASH_THRESHOLD = 512;
using namespace at_npu::native;
const static int N = 32;

enum class DropOutStatus {
    DROPOUT_NORMAL = 0,
@ -52,7 +44,7 @@ at::Tensor format_trans(const at::Tensor &at_tensor)
{
    if (at_tensor.defined()) {
        TORCH_CHECK(torch_npu::utils::is_npu(at_tensor), "only npu tensor is supported");
-        return custom_ops::npu_format_cast(at_tensor, ACL_FORMAT_ND);
+        return at_npu::native::NPUNativeFunctions::npu_format_cast(at_tensor, ACL_FORMAT_ND);
    }
    return at_tensor;
}
@ -62,16 +54,16 @@ at::Tensor dropout_gen_mask_impl(const at::Tensor &query, const at::Scalar &keep
{
    int64_t length = (numels + 128 - 1) / 128 * 128 / 8;
    c10::TensorOptions options = query.options();
-    at::Tensor mask = OpPreparation::apply_tensor_without_format(at::IntArrayRef{length + 32}, options.dtype(at::kByte));
-    at::SmallVector<int64_t, ::N> offsetList = {0, offset};
+    at::Tensor mask = at::empty(at::IntArrayRef{length + 32}, options.dtype(at::kByte));
+    at::SmallVector<int64_t, N> offsetList = {0, offset};
    const int64_t seed1 = 0;
-    OpCommand cmd;
+    at_npu::native::OpCommand cmd;
    cmd.Name("StatelessDropOutGenMask")
        .Input(at::IntArrayRef{numels})
-        .Input(keep_prob, query.scalar_type(), CompileType::MEMORY_HOST_COMPILE_DEPENDENT)
+        .Input(keep_prob, query.scalar_type(), at_npu::native::CompileType::MEMORY_HOST_COMPILE_DEPENDENT)
        .Input(seed, at::ScalarType::Int)
        .Input(at::Scalar(seed1), at::ScalarType::Int)
-        .Input(offsetList, at::kLong, CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
+        .Input(offsetList, at::kLong, at_npu::native::CompileType::MEMORY_HOST_COMPILE_INDEPENDENT)
        .Output(mask)
        .Run();
    return mask;
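A note on the sizing arithmetic in dropout_gen_mask_impl: length rounds numels up to a multiple of 128 bits and converts bits to bytes, and 32 extra bytes are reserved. For example, numels = 1000 gives (1000 + 127) / 128 * 128 / 8 = 128 bytes, so the mask tensor holds 160 bytes, one dropout bit per element plus alignment padding.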
@ -91,9 +83,6 @@ at::Tensor dropout_gen_mask_dispatch(const at::Tensor &query, const at::Scalar &
            // alloced from the pool of the secondary stream.
            c10_npu::SecondaryStreamGuard guard(c10_npu::getCurrentSecondaryStream());
            mask = dropout_gen_mask_impl(query, keep_prob, seed, offset, numels);
-            if (sync) {
-                NPU_CHECK_ERROR(c10_npu::acl::AclrtSynchronizeStreamWithTimeout(original_stream));
-            }
        }
    } else {
        mask = dropout_gen_mask_impl(query, keep_prob, seed, offset, numels);
@ -148,7 +137,9 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_flash_attention_b
    double keep_prob,
    int64_t pre_tockens,
    int64_t next_tockens,
-    int64_t inner_precise)
+    int64_t inner_precise,
+    c10::optional<at::IntArrayRef> prefix,
+    int64_t sparse_mode)
{
    double scale = scale_value;

@ -160,6 +151,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_flash_attention_b
    const at::Tensor &softmax_sum_const = softmax_sum.value_or(at::Tensor());
    const at::Tensor &softmax_const = softmax_in.value_or(at::Tensor());
    const at::Tensor &attention_const = attention_in.value_or(at::Tensor());
+    auto prefixN = prefix.value_or(at::IntArrayRef{});

    at::Tensor format_query = format_trans(query);
    at::Tensor format_key = format_trans(key);
@ -174,22 +166,22 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_flash_attention_b
    at::Tensor format_softmax_sum = format_trans(softmax_sum_const);
    at::Tensor format_softmax = format_trans(softmax_const);
    at::Tensor format_attention = format_trans(attention_const);
-    at::Tensor dq = OpPreparation::apply_tensor_without_format(format_query);
-    at::Tensor dk = OpPreparation::apply_tensor_without_format(format_key);
-    at::Tensor dv = OpPreparation::apply_tensor_without_format(format_value);
+    at::Tensor dq = at::empty(format_query.sizes(), format_query.options());
+    at::Tensor dk = at::empty(format_key.sizes(), format_key.options());
+    at::Tensor dv = at::empty(format_value.sizes(), format_value.options());
    char* input_layout_ptr = const_cast<char *>(input_layout.c_str());
    at::Tensor dpse;
    if (format_pse.defined()) {
-        dpse = OpPreparation::apply_tensor_without_format(format_pse);
+        dpse = at::empty(format_pse.sizes(), format_pse.options());
    } else {
        dpse = at::empty({0}, query.options());
    }

-    EXEC_NPU_NO_FORMAT_CHECK_CMD(
+    ACLNN_CMD(
        aclnnFlashAttentionScoreGrad, format_query, format_key, format_value, format_dy,
-        format_pse, format_drop_mask, format_padding_mask, format_atten_mask,
-        format_softmax_max, format_softmax_sum, format_softmax, format_attention, scale_value, keep_prob,
-        pre_tockens, next_tockens, head_num, input_layout_ptr, inner_precise, dq, dk, dv, dpse);
+        format_pse, format_drop_mask, format_padding_mask, format_atten_mask, format_softmax_max,
+        format_softmax_sum, format_softmax, format_attention, prefixN, scale_value, keep_prob, pre_tockens,
+        next_tockens, head_num, input_layout_ptr, inner_precise, sparse_mode, dq, dk, dv, dpse);

    if (!format_pse.defined()) {
        at::Tensor dpse_required;
@ -205,7 +197,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_flash_attention_g
    const at::Tensor &value,
    const at::Tensor &dy,
    int64_t head_num,
-    c10::string_view input_layout,
+    std::string input_layout,
    const c10::optional<at::Tensor> &pse,
    const c10::optional<at::Tensor> &padding_mask,
    const c10::optional<at::Tensor> &atten_mask,
@ -221,6 +213,8 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_flash_attention_g
    int64_t seed,
    int64_t offset,
    int64_t numels,
+    c10::optional<at::IntArrayRef> prefix,
+    int64_t sparse_mode,
    bool gen_mask_parallel,
    bool sync)
{
@ -247,7 +241,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_flash_attention_g
    auto result = npu_flash_attention_backward(query,
        key, value, dy, head_num, input_layout_str, pse, drop_mask, padding_mask, atten_mask,
        softmax_max, softmax_sum, softmax_in, attention_in, scale_value,
-        keep_prob, pre_tockens, next_tockens, inner_precise);
+        keep_prob, pre_tockens, next_tockens, inner_precise, prefix, sparse_mode);
    if (!sync) {
        c10_npu::NPUEvent npu_event;
        npu_event.record(c10_npu::getCurrentNPUStream());
@ -259,20 +253,22 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_flash_attention_g

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, int64_t, int64_t, int64_t> npu_flash_attention(
    const at::Tensor &query, const at::Tensor &key,
-    const at::Tensor &value, int64_t head_num, c10::string_view input_layout,
+    const at::Tensor &value, int64_t head_num, std::string input_layout,
    const c10::optional<at::Tensor> &pse_opt, const c10::optional<at::Tensor> &padding_mask_opt,
    const c10::optional<at::Tensor> &atten_mask_opt,
    double scale, double keep_prob, int64_t pre_tockens, int64_t next_tockens, int64_t inner_precise,
-    bool gen_mask_parallel, bool sync)
+    c10::optional<at::IntArrayRef> prefix_opt, int64_t sparse_mode, bool gen_mask_parallel, bool sync)
{
    const at::Tensor &pse = pse_opt.value_or(at::Tensor());
    const at::Tensor &padding_mask = padding_mask_opt.value_or(at::Tensor());
    const at::Tensor &atten_mask = atten_mask_opt.value_or(at::Tensor());
+    auto prefixN = prefix_opt.value_or(at::IntArrayRef{});

    TORCH_CHECK(query.dim() == 3 || query.dim() == 4, "The shapes of the input query should be 3 or 4 dimensional, but got ", query.dim(), "-dimensional");
    TORCH_CHECK(key.dim() == 3 || key.dim() == 4, "The shapes of the input key should be 3 or 4 dimensional, but got ", key.dim(), "-dimensional");
    TORCH_CHECK(value.dim() == 3 || value.dim() == 4, "The shapes of the input value should be 3 or 4 dimensional, but got ", value.dim(), "-dimensional");
    TORCH_CHECK(keep_prob >= 0 && keep_prob <= 1, "The keep_prob value must be in range of [0, 1], but got ", keep_prob);
+    TORCH_CHECK(sparse_mode >= 0 && sparse_mode <= 5, "The sparse_mode value must be in range of [0~5], but got ", sparse_mode);
    std::string input_layout_str = std::string(input_layout);
    for (auto &c : input_layout_str) {
        c = toupper(c);
@ -313,7 +309,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, int64_t, int64_t, int
    double scale_value = scale;

    at::Tensor format_query = format_trans(query);
-    at::Tensor attention_score = OpPreparation::apply_tensor_without_format(format_query);
+    at::Tensor attention_score = at::empty(format_query.sizes(), format_query.options());
    at::Tensor format_key = format_trans(key);
    at::Tensor format_value = format_trans(value);

@ -331,17 +327,15 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, int64_t, int64_t, int
    at::Tensor softmax_sum;
    at::Tensor softmax_out;

-    softmax_max = OpPreparation::apply_tensor_without_format({B, head_num, S0, 8},
-        query.options().dtype(at::kFloat)); // [B, N, S0, 8]
-    softmax_sum = OpPreparation::apply_tensor_without_format({B, head_num, S0, 8},
-        query.options().dtype(at::kFloat)); // [B, N, S0, 8]
+    softmax_max = at::empty({B, head_num, S0, 8}, query.options().dtype(at::kFloat)); // [B, N, S0, 8]
+    softmax_sum = at::empty({B, head_num, S0, 8}, query.options().dtype(at::kFloat)); // [B, N, S0, 8]
    softmax_out = at::empty({0}, query.options());

    char* input_layout_ptr = const_cast<char *>(input_layout_str.c_str());
-    EXEC_NPU_NO_FORMAT_CHECK_CMD(aclnnFlashAttentionScore, format_query, format_key, format_value,
-        format_pse, format_drop_mask, format_padding_mask, format_atten_mask,
+    ACLNN_CMD(aclnnFlashAttentionScore, format_query, format_key, format_value,
+        format_pse, format_drop_mask, format_padding_mask, format_atten_mask, prefixN,
        scale, keep_prob, pre_tockens, next_tockens, head_num, input_layout_ptr, inner_precise,
-        softmax_max, softmax_sum, softmax_out, attention_score);
+        sparse_mode, softmax_max, softmax_sum, softmax_out, attention_score);

    if (!sync) {
        c10_npu::NPUEvent npu_event;

@ -18,7 +18,6 @@
#include "common.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-    m.def("npu_flash_attention", &npu_flash_attention, "flash attention forward");
-    m.def("npu_flash_attention_grad", &npu_flash_attention_grad, "flash attention backward");
-    m.def("npu_add", &add, "test transformer boost add op");
+    m.def("npu_fusion_attention", &npu_flash_attention, "fusion attention forward");
+    m.def("npu_fusion_attention_grad", &npu_flash_attention_grad, "fusion attention backward");
}

@ -5,41 +5,45 @@ import ascendspeed_te_ops

class AscendFlashAttention(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, query, key, value, head_num, input_layout, pse=None, padding_mask=None, atten_mask=None,
-                scale=1.0, keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, inner_precise=1,
-                gen_mask_parallel=True, sync=False):
+    def forward(ctx, query, key, value, head_num, input_layout, pse=None, padding_mask=None,
+                atten_mask=None, scale=1., keep_prob=1., pre_tokens=2147483647, next_tokens=2147483647,
+                inner_precise=0, prefix=None, sparse_mode=0, gen_mask_parallel=True, sync=False):

+        outputs = ascendspeed_te_ops.npu_fusion_attention(query,
+            key, value, head_num, input_layout, pse, padding_mask,
+            atten_mask, scale, keep_prob, pre_tokens, next_tokens,
+            inner_precise, prefix, sparse_mode, gen_mask_parallel, sync)
+        attention_in, softmax_max, softmax_sum, softmax_in, seed, offset, numels = outputs
+        ctx.save_for_backward(
+            query, key, value, pse, padding_mask, atten_mask, attention_in, softmax_max,
+            softmax_sum, softmax_in)
        ctx.scale = scale
+        ctx.keep_prob = keep_prob
        ctx.input_layout = input_layout
        ctx.head_num = head_num
-        ctx.pre_tokens = pre_tockens
-        ctx.next_tokens = next_tockens
+        ctx.pre_tokens = pre_tokens
+        ctx.next_tokens = next_tokens
        ctx.inner_precise = inner_precise
        ctx.gen_mask_parallel = gen_mask_parallel
        ctx.sync = sync

-        outputs = ascendspeed_te_ops.npu_flash_attention(
-            query, key, value, head_num, input_layout, pse, padding_mask,
-            atten_mask, scale, keep_prob, pre_tockens, next_tockens,
-            inner_precise, gen_mask_parallel, sync)
-
-        attention_score, softmax_max, softmax_sum, softmax_out, seed, offset, numels = outputs
-        ctx.saved_for_backward(
-            query, key, value, pse, padding_mask, atten_mask, attention_score, softmax_max,
-            softmax_sum, softmax_out, seed, offset, numels
-        )
+        ctx.seed = seed
+        ctx.offset = offset
+        ctx.numels = numels
+        ctx.prefix = prefix
+        ctx.sparse_mode = sparse_mode

        return outputs

    @staticmethod
    def backward(ctx, grad_outputs):
-        query, key, value, pse, padding_mask, atten_mask, attention_score, softmax_max,\
-            softmax_sum, softmax_out, seed, offset, numels = ctx.saved_tensors
-        results = ascendspeed_te_ops.npu_flasg_attention_grad(
+        query, key, value, pse, padding_mask, atten_mask, attention_in, softmax_max,\
+            softmax_sum, softmax_in = ctx.saved_tensors
+        results = ascendspeed_te_ops.npu_fusion_attention_grad(
            query, key, value, grad_outputs, ctx.head_num, ctx.input_layout, pse, padding_mask, atten_mask,
-            softmax_max, softmax_sum, softmax_out, attention_score, ctx.scale, ctx.keep_prob, ctx.pre_tokens,
-            ctx.next_tokens, ctx.inner_precise, seed, offset, numels, ctx.gen_mask_parallel, ctx.sync)
-        return results
+            softmax_max, softmax_sum, softmax_in, attention_in, ctx.scale, ctx.keep_prob, ctx.pre_tokens,
+            ctx.next_tokens, ctx.inner_precise, ctx.seed, ctx.offset, ctx.numels, ctx.prefix, ctx.sparse_mode,
+            ctx.gen_mask_parallel, ctx.sync)

+        return results, None, None, None, None, None, None, None, None, None, None, None, None, None


ascend_flash_attention = AscendFlashAttention.apply

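A hedged usage sketch of the wrapper above, not part of the commit: shapes, layout, and device string are illustrative, and autograd.Function.apply takes positional arguments only.

import torch
import torch_npu  # registers the NPU device

q = torch.randn(1, 2048, 1024, dtype=torch.half, device="npu", requires_grad=True)
k = torch.randn(1, 2048, 1024, dtype=torch.half, device="npu", requires_grad=True)
v = torch.randn(1, 2048, 1024, dtype=torch.half, device="npu", requires_grad=True)

outputs = ascend_flash_attention(q, k, v, 8, "BSH")  # head_num=8, BSH layout
outputs[0].sum().backward()  # drives npu_fusion_attention_grad via backward()
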
122
setup.py
Normal file
@ -0,0 +1,122 @@
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Setup for pip package."""

import sys
import setuptools

if sys.version_info < (3,):
    raise Exception("Python 2 is not supported by AscendSpeed.")

__description__ = 'AscendSpeed for LLMs of Ascend'
__version__ = '0.0.1'
__author__ = 'Ascend'
__long_description__ = 'AscendSpeed for LLMs of Ascend'
__url__ = 'https://gitee.com/ascend/AscendSpeed'
__download_url__ = 'https://gitee.com/ascend/AscendSpeed/release'
__keywords__ = 'Ascend, language, deep learning, NLP'
__license__ = 'See https://gitee.com/ascend/AscendSpeed'
__package_name__ = 'ascendspeed'
__contact_names__ = 'Ascend'


with open("README.md", "r") as fh:
    long_description = fh.read()

###############################################################################
#                             Dependency Loading                              #
# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% #


def req_file(filename):
    with open(filename) as f:
        content = f.readlines()
    return [x.strip() for x in content]


install_requires = req_file("requirements.txt")
cmd_class = {}
exts = []
try:
    import imp
    import os
    import glob
    from torch_npu.utils import cpp_extension
    from torch.utils.cpp_extension import BuildExtension

    ASCEND_TOOLKIT_HOME = os.environ.get("ASCEND_TOOLKIT_HOME")
    op_files = glob.glob("ascendspeed/te/ops/csrc/*.cpp")
    ext_ops = cpp_extension.NpuExtension(
        name="ascendspeed_te_ops",
        sources=op_files,
        extra_compile_args=[
            '-Wno-sign-compare',
            '-Wno-deprecated-declarations',
            '-Wno-return-type',
            '-D__FILENAME__=\"$(notdir $(abspath $<))\"',
            '-I' + imp.find_module('torch_npu')[1] + "/include/third_party/acl/inc",
            '-I' + ASCEND_TOOLKIT_HOME + '/include/',
        ],
    )
    exts.append(ext_ops)
    cmd_class = {"build_ext": BuildExtension.with_options(use_ninja=False)}
except Exception:
    print('Cannot find env: ASCEND_TOOLKIT_HOME or ATB_HOME_PATH, ops setup failed')

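# A hedged note, not part of the commit: with exts/cmd_class populated above,
# the C++ adapter is compiled during a normal install, e.g.
#   export ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest  # machine-specific
#   pip install -e .    # or: python setup.py build_ext --inplace
# after which `import ascendspeed_te_ops` exposes the ops bound in the
# PYBIND11_MODULE block earlier in this commit.
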
setuptools.setup(
    package_data={'ascendspeed': ['ascendspeed/data/Makefile']},
    name=__package_name__,
    # Versions should comply with PEP440. For a discussion on single-sourcing
    # the version across setup.py and the project code, see
    # https://packaging.python.org/en/latest/single_source_version.html
    version=__version__,
    description=__description__,
    long_description=long_description,
    long_description_content_type="text/markdown",
    # The project's main homepage.
    url=__url__,
    author=__contact_names__,
    maintainer=__contact_names__,
    # The licence under which the project is released
    license=__license__,
    classifiers=[
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'Intended Audience :: Information Technology',
        # Indicate what your project relates to
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'Topic :: Software Development :: Libraries :: Python Modules',
        # Supported python versions
        'Programming Language :: Python :: 3.6',
        'Programming Language :: Python :: 3.7',
        'Programming Language :: Python :: 3.8',
        'Programming Language :: Python :: 3.9',
        # Additional Setting
        'Environment :: Console',
        'Natural Language :: English',
        'Operating System :: OS Independent',
    ],
    python_requires='>=3.6',
    packages=setuptools.find_packages(),
    install_requires=install_requires,
    # Add in any packaged data.
    include_package_data=True,
    zip_safe=False,
    # PyPI package information.
    keywords=__keywords__,
    cmdclass=cmd_class,
    ext_modules=exts
)