修改矩阵乘法偏导

This commit is contained in:
李大鹏 2024-02-26 17:40:29 +08:00
parent 0d4429ff43
commit 7a0fa09491
2 changed files with 6 additions and 27 deletions

View File

@ -416,31 +416,10 @@ public class MatrixOperation {
public static Matrix matrixMulPd(Matrix errorMatrix, Matrix first, Matrix second, boolean isFirstPd) throws Exception {//对两个相乘的矩阵求偏导
Matrix matrix;
int x, y;
if (isFirstPd) {//对相乘的前矩阵进行求导
Matrix st = transPosition(second);//对矩阵2进行转置
x = first.getX();
y = first.getY();
matrix = new Matrix(x, y);
for (int i = 0; i < x; i++) {
for (int j = 0; j < y; j++) {
double errorSigma = errorMatrix.getSigmaByVector(true, i);
double xt = st.getSigmaByVector(false, j);
matrix.setNub(i, j, errorSigma * xt);
}
}
matrix = mulMatrix(errorMatrix, transPosition(second));
} else {
Matrix ft = transPosition(first);
x = second.getX();
y = second.getY();
matrix = new Matrix(x, y);
for (int i = 0; i < x; i++) {
for (int j = 0; j < y; j++) {
double errorSigma = errorMatrix.getSigmaByVector(false, j);
double at = ft.getSigmaByVector(true, i);
matrix.setNub(i, j, errorSigma * at);
}
}
matrix = mulMatrix(transPosition(first), errorMatrix);
}
return matrix;
}

View File

@ -1,7 +1,6 @@
package org.wlld.rnnJumpNerveEntity;
import org.wlld.MatrixTools.Matrix;
import org.wlld.i.ActiveFunction;
import org.wlld.i.OutBack;
@ -36,12 +35,13 @@ public class HiddenNerve extends Nerve {
} else {
double sigma = calculation(eventId);
double out = activeFunction.function(sigma);//激活函数输出数值
if (isKernelStudy) {
outNub = out;
}
if (rnnMatrix != null) {//rnn 1改输出值2查看是否需要转向
out = out + rnnMatrix.getNumber(depth, getId() - 1);
}
if (isKernelStudy) {
outNub = out;
} else {
if (!isKernelStudy) {
destroyParameter(eventId);
}
sendMessage(eventId, out, isKernelStudy, E, outBack, false, rnnMatrix, storeys, index);