這篇文章會展示一下如何使用現代化工具,快速把論文上的公式轉換成程式碼
1.首先到
網站註冊一個帳號並下載snipping
這個軟體是專門讓使用者利用截圖的方式把數學式子轉化成LATEX格式
我以一個論文看到的演算法為範例
好然後將LATEX複製到COLAB貼上為了驗證他的正確性
ok沒問題接下來將LATEX複製到chatgpt貼上
大功告成,測試一下吧
挖,一次就成功耶,本來以為可能會有語法錯誤什麼的
沒想到一次就成功
完整代碼:
LATEX
\begin{equation}
\text { Set } y_j= \begin{cases}r_j & \text { for terminal } \phi_{j+1} \\ r_j+\gamma \max _{a^{\prime}} Q\left(\phi_{j+1}, a^{\prime} ; \theta\right) & \text { for non-terminal } \phi_{j+1}\end{cases}
\end{equation}
pytorch程式碼
import torch
import torch.nn as nn
class QNetwork(nn.Module):
def __init__(self, input_dim, output_dim):
super(QNetwork, self).__init__()
self.fc = nn.Linear(input_dim, output_dim)
def forward(self, phi):
return self.fc(phi)
def compute_target(model, r_j, phi_j_plus_1, gamma):
with torch.no_grad():
max_q_value = model(phi_j_plus_1).max(1)[0]
y_j = r_j + gamma * max_q_value
return y_j
import torch
# 创建QNetwork模型
input_dim = 4 # 输入维度
output_dim = 2 # 输出维度
model = QNetwork(input_dim, output_dim)
# 定义输入和其他参数
r_j = torch.tensor([0.5]) # 立即奖励
phi_j_plus_1 = torch.tensor([[1.0, 2.0, 3.0, 4.0]]) # 下一个状态
gamma = 0.9 # 折扣因子
# 计算目标值
y_j = compute_target(model, r_j, phi_j_plus_1, gamma)
print(y_j)