矩陣相乘是一個經典的演算法問題,經常出現在技術面試中,用來評估你對多維陣列的操作和程式設計邏輯的掌握。雖然Python有NumPy這樣的庫可以輕鬆實現矩陣相乘,但面試通常要求你手寫基礎實現,展示對原理的理解。本文將以標準矩陣相乘為例,介紹其基礎概念、Python程式碼、時間與空間複雜度,並提供面試應答技巧。
什麼是矩陣相乘?
矩陣相乘是將兩個矩陣(二維陣列)相乘,得到一個新矩陣。假設你有兩個矩陣:
- 矩陣 A 是 m x n(m 行 n 列)。
- 矩陣 B 是 n x p(n 行 p 列)。
- 結果矩陣 C 是 m x p。
C
的每個元素 C[i][j]
是 A
的第 i
行與 B
的第 j
列的點積(對應元素相乘後求和)。簡單比喻
想像你在製作一份成績表:
- A 是「學生 x 科目」的分數表。
- B 是「科目 x 權重」的權重表。
- C 是「學生 x 加權分數」的結果表。 每個加權分數是學生在各科目的分數乘以對應權重,然後加總。
範例
輸入:
A = [[1, 2],
[3, 4]]
B = [[5, 6],
[7, 8]]
輸出:
C = [[19, 22],
[43, 50]]
計算過程:
C[0][0] = A[0][0] * B[0][0] + A[0][1] * B[1][0] = 1*5 + 2*7 = 19
C[0][1] = A[0][0] * B[0][1] + A[0][1] * B[1][1] = 1*6 + 2*8 = 22
C[1][0] = A[1][0] * B[0][0] + A[1][1] * B[1][0] = 3*5 + 4*7 = 43
C[1][1] = A[1][0] * B[0][1] + A[1][1] * B[1][1] = 3*6 + 4*8 = 50
解法:標準矩陣相乘
矩陣相乘的標準方法使用三層巢狀迴圈,計算每個 C[i][j]
:
- 初始化結果矩陣:
- 創建 m x p 的矩陣 C,初始值為 0。
- 計算每個元素:
- 外層迴圈遍歷 C 的行(i 從 0 到 m-1)。
- 中層迴圈遍歷 C 的列(j 從 0 到 p-1)。
- 內層迴圈計算點積(k 從 0 到 n-1),累加 A[i][k] * B[k][j]。
- 返回結果: 返回填充好的矩陣 C。
簡單比喻
想像你在填一張表格,每個格子(C[i][j]
)需要把 A
的一行和 B
的一列對應數字相乘,然後加起來。三層迴圈就像你在逐行逐列填表,內層迴圈負責算每個格子的值。
程式碼範例
以下是Python實現,簡單且易於面試手寫:
def matrix_multiply(A, B):
# 獲取矩陣維度
m = len(A) # A 的行數
n = len(A[0]) # A 的列數,B 的行數
p = len(B[0]) # B 的列數
# 初始化結果矩陣 C,大小為 m x p
C = [[0 for _ in range(p)] for _ in range(m)]
# 計算矩陣乘積
for i in range(m): # 遍歷 C 的行
for j in range(p): # 遍歷 C 的列
for k in range(n): # 計算點積
C[i][j] += A[i][k] * B[k][j]
return C
# 測試程式碼
A = [[1, 2],
[3, 4]]
B = [[5, 6],
[7, 8]]
result = matrix_multiply(A, B)
for row in result:
print(row) # 輸出: [19, 22]
# [43, 50]
程式碼解釋
- 輸入:
- A 是 m x n 矩陣,B 是 n x p 矩陣。
- 假設輸入有效(len(A[0]) == len(B))。
- 初始化:
- 使用列表推導式創建 m x p 的結果矩陣 C,初始值為 0。
- 計算:
- 三層迴圈:
- i 遍歷 C 的行(0 到 m-1)。
- j 遍歷 C 的列(0 到 p-1)。
- k 遍歷 A 的列和 B 的行(0 到 n-1)。
- 計算 C[i][j] += A[i][k] * B[k][j]。
- 輸出:
- 返回結果矩陣 C。
- Python優勢:無需手動管理記憶體,列表操作直觀。
時間與空間複雜度
- 時間複雜度:O(m * n * p)
- 三層迴圈分別迭代 m、p、n 次。
- 這是標準矩陣相乘的複雜度,進階方法(如Strassen演算法)可降低到 O(n^2.807),但面試中很少要求。
- 空間複雜度:O(m * p)
- 用於儲存結果矩陣 C。
- 如果不計輸出,額外空間為 O(1)(僅用幾個迴圈變數)。
面試應答策略
- 理解題目:
- 確認矩陣維度是否有效(A 的列數是否等於 B 的行數)。
- 詢問是否有特殊情況(空矩陣、單行/列矩陣)。
- 確認是否需要檢查輸入(例如,矩陣是否為空)。
- 問清楚是否允許使用庫(如NumPy),通常面試要求手寫。
- 講解思路:
- 用簡單比喻(成績表、填表格)讓面試官明白你的思考。
- 畫圖展示矩陣維度和點積計算(例如,C[0][0] = A[0][0] * B[0][0] + A[0][1] * B[1][0])。
- 提到三層迴圈的邏輯:i 控制行,j 控制列,k 計算點積。
- 說明時間複雜度(O(m * n * p))和空間複雜度(O(m * p))。
- 程式碼實現:
- 寫乾淨的程式碼,使用有意義的變數名(m, n, p 而不是 x, y, z)。
- 註釋關鍵步驟,特別是迴圈的作用(行、列、點積)。
- 處理邊界情況(可加入輸入檢查):
if not A or not B or len(A[0]) != len(B):
return []
- 進階討論:
- 如果面試官問到優化,提到:
- 快取友好:調整迴圈順序(例如,i, k, j)以提高快取命中率。
- 並行化:將矩陣分塊,交給多執行緒或GPU處理。
- 進階演算法:Strassen演算法,但強調實務中標準方法更常見(因簡單且易維護)。
- 討論實際應用:
- 矩陣相乘用於機器學習(神經網路)、圖形學(變換矩陣)、資料分析。
- Python中,實務會用NumPy(np.dot)加速計算。
- 提到Python與C的區別:Python列表操作簡單,但C需要手動分配二維陣列記憶體。
- 常見追問:
- 「如果矩陣很大怎麼辦?」回答:分塊處理,減少記憶體使用;使用稀疏矩陣(若適用);考慮分散式計算。
- 「如何處理稀疏矩陣?」回答:僅儲存非零元素(用字典或壓縮格式如CSR),只計算非零項的乘積。
- 「如何處理浮點數精度?」回答:使用高精度數學庫(如 decimal),或在最後四捨五入。
- 「Python與C++實現的區別?」回答:Python無需手動記憶體管理,列表推導式簡化初始化;C++需用指標或 std::vector,小心記憶體洩漏。
- 模擬面試準備:
- 在LeetCode練習矩陣相關題目(#73 Set Matrix Zeroes, #54 Spiral Matrix)。
- 手寫程式碼,模擬白板環境,練習邊寫邊講解。
- 測試邊界情況(空矩陣、1x1矩陣、不相容維度)。
- 熟悉Python的列表操作,確保不誤用賦值或索引。
進階範例:加入輸入檢查
以下是一個更穩健的實現,包含輸入驗證,展示面試中如何處理邊界情況:
def matrix_multiply(A, B):
# 檢查輸入
if not A or not B or not A[0] or not B[0]:
return []
if len(A[0]) != len(B):
return []
# 獲取矩陣維度
m = len(A) # A 的行數
n = len(A[0]) # A 的列數,B 的行數
p = len(B[0]) # B 的列數
# 初始化結果矩陣
C = [[0 for _ in range(p)] for _ in range(m)]
# 計算矩陣乘積
for i in range(m):
for j in range(p):
for k in range(n):
C[i][j] += A[i][k] * B[k][j]
return C
# 測試
A = [[1, 2],
[3, 4]]
B = [[5, 6],
[7, 8]]
result = matrix_multiply(A, B)
for row in result:
print(row) # 輸出: [19, 22]
# [43, 50]
# 測試邊界
A_empty = []
B_invalid = [[1], [2]]
print(matrix_multiply(A_empty, B)) # 輸出: []
print(matrix_multiply(A, B_invalid)) # 輸出: []
說明
- 輸入檢查:
- 檢查空矩陣(A 或 B 為空,或包含空行)。
- 檢查維度是否相容(len(A[0]) != len(B))。
- 穩健性:返回空列表表示無效輸入,符合面試對邊界處理的要求。
- 清晰性:程式碼結構分明,註釋說明每個步驟。
結語
矩陣相乘是面試中的基礎題目,掌握標準三層迴圈實現能展示你對陣列操作和演算法的理解。使用Python時,重點在於寫出乾淨的程式碼並處理邊界情況。練習時,專注於講解邏輯(用成績表比喻)、分析複雜度,並準備回答優化問題。模擬面試時,練習邊寫邊講解,確保程式碼無誤且表達清晰。其他矩陣題目(如旋轉矩陣、稀疏矩陣)也可能出現,建議在LeetCode上多練幾題。祝大家面試順利!