#include <iostream> #include <vector> using namespace std;
class Matrix {
public: Matrix(int);
Matrix(const Matrix&);
Matrix& operator=(const Matrix&); int operator()(int,int) const; int& operator()(int,int);
Matrix& operator +=(const Matrix&); Matrix& operator -=(const Matrix&); Matrix& operator *=(int);
Matrix& operator *=(const Matrix&); Matrix GetQuarter(int) const ;//得到1/4个矩阵
Matrix& SetQuarter(const Matrix&,int);//设置1/4个矩阵 int Side() const;//矩阵的行/列 void Show() const;//打印矩阵 private:
void Malloc(int);//设置一个仿二维数组的大小 const int MN;//矩阵的行/列
vector< vector<int> > Data;//矩阵的数据 };
Matrix operator +(const Matrix&,const Matrix&);//全局函数:矩阵相加 Matrix operator -(const Matrix&,const Matrix&);//全局函数:矩阵相减 Matrix operator *(int,const Matrix&);//全局函数:整数乘以矩阵,在本程序中没用到
Matrix operator *(const Matrix&,int);//全局函数:矩阵乘以整数,在本程序中没用到
Matrix operator *(const Matrix&,const Matrix&);//全局函数:矩阵相乘
void Matrix::Malloc(int mn)//设定一个仿二维数组,大小为mn*mn; {
Data.resize(mn);
for(int i=0;i<mn;++i) Data[i].resize(mn);
}
Matrix::Matrix(int mn):MN(mn) {
Malloc(MN); }
Matrix::Matrix(const Matrix& rhs):MN(rhs.MN)//拷贝构造 {
Malloc(MN);
for(int i=0;i<MN;++i) for(int j=0;j<MN;++j) Data[i][j]=rhs.Data[i][j]; }
Matrix& Matrix::operator=(const Matrix& rhs)//矩阵赋值 {
if(MN!=rhs.MN) throw; for(int i=0;i<MN;++i) for(int j=0;j<MN;++j) Data[i][j]=rhs.Data[i][j]; return *this; }
int Matrix::operator()(int m,int n) const//得到矩阵的某个元素,坐标从0开始 { return Data[m][n]; }
int& Matrix::operator()(int m,int n)//得到矩阵的某个元素,坐标从0开始,可以赋值
{ return Data[m][n]; }
Matrix& Matrix::operator +=(const Matrix& rhs)//类公开函数,A+=B; {
if(MN!=rhs.MN) throw; for(int i=0;i<MN;++i) for(int j=0;j<MN;++j) Data[i][j]+=rhs.Data[i][j]; return *this; }
Matrix& Matrix::operator -=(const Matrix& rhs)//类公开函数,A-=B; {
if(MN!=rhs.MN) throw; for(int i=0;i<MN;++i) for(int j=0;j<MN;++j) Data[i][j]-=rhs.Data[i][j]; return *this; }
Matrix& Matrix::operator *=(int Num)//类公开函数,A*=i; {
for(int i=0;i<MN;++i) for(int j=0;j<MN;++j) Data[i][j]*=Num; return *this; }
Matrix Matrix::GetQuarter(int Pos) const//得到1/4个矩阵,Pos=0代表左上,Pos=1代表右上... {
int X,Y;
switch(Pos%4) {
case 0:X=Y=0;break;
case 1:X=0;Y=MN/2;break; case 2:X=MN/2;Y=0;break; case 3:X=Y=MN/2; }
Matrix T(MN/2);
for(int i=0;i<T.MN;++i) for(int j=0;j<T.MN;++j) T.Data[i][j]=Data[i
+X][j+Y]; return T; }
Matrix& Matrix::SetQuarter(const Matrix& rhs,int Pos)//把rhs的值拷贝为自身的1/4个矩阵,Pos含义同上 {
int X,Y;
switch(Pos%4) {
case 0:X=Y=0;break;
case 1:X=0;Y=MN/2;break; case 2:X=MN/2;Y=0;break; case 3:X=Y=MN/2; }
for(int i=0;i<rhs.MN;++i) for(int j=0;j<rhs.MN;++j) Data[i+X][j+Y]=rhs.Data[i][j]; return *this; }
int Matrix::Side() const//得到行/列 {
return MN; }
Matrix& Matrix::operator *=(const Matrix& rhs)//类公开函数 A*=B; {
*this= *this * rhs;//调用全局函数:矩阵相乘 return *this; }
void Matrix::Show() const//打印矩阵
{ cout<<"Display Matrix:"<<endl; for(int i=0;i<MN;++i){ for(int j=0;j<MN;++j)
cout<<Data[i][j]<<' '; cout<<endl; } }
Matrix operator +(const Matrix& rhs1,const Matrix& rhs2)//全局函数:矩阵相加 {
Matrix T(rhs1);
return T+=rhs2;//调用类公开函数+= }
Matrix operator -(const Matrix& rhs1,const Matrix& rhs2)//全局函数:矩阵相减 {
Matrix T(rhs1);
return T-=rhs2;//调用类公开函数-= }
Matrix operator *(int Num,const Matrix& rhs)//全局函数:整数乘以矩阵 {
Matrix T(rhs);
return T*=Num;//调用类公开函数*= }
Matrix operator *(const Matrix& rhs,int Num)//全局函数:矩阵乘以整数 {
Matrix T(rhs);
return T*=Num;//调用类公开函数*= }
Matrix operator *(const Matrix& rhs1,const Matrix& rhs2)//全局函数,矩阵相乘
{ if(rhs1.Side()!=rhs2.Side()) throw;
if(rhs1.Side()==2){//行/列为2,按照常规方法计算 Matrix T(2);
T(0,0)=rhs1(0,0)*rhs2(0,0)+rhs1(0,1)*rhs2(1,0);//A11B11+A12B21; T(0,1)=rhs1(0,0)*rhs2(0,1)+rhs1(0,1)*rhs2(1,1);//A11B12+A12B22 T(1,0)=rhs1(1,0)*rhs2(0,0)+rhs1(1,1)*rhs2(1,0);//A21B11+A22B21 T(1,1)=rhs1(1,0)*rhs2(0,1)+rhs1(1,1)*rhs2(1,1);//A21B12+A22B22 return T; };
Matrix A11(rhs1.GetQuarter(0));//第一个矩阵的左上1/4矩阵 Matrix A12(rhs1.GetQuarter(1));//第一个矩阵的右上1/4矩阵 Matrix A21(rhs1.GetQuarter(2));//第一个矩阵的左下1/4矩阵 Matrix A22(rhs1.GetQuarter(3));//第一个矩阵的右下1/4矩阵 Matrix B11(rhs2.GetQuarter(0));//第二个矩阵的左上1/4矩阵 Matrix B12(rhs2.GetQuarter(1));//第二个矩阵的右上1/4矩阵 Matrix B21(rhs2.GetQuarter(2));//第二个矩阵的左下1/4矩阵 Matrix B22(rhs2.GetQuarter(3));//第二个矩阵的右下1/4矩阵 Matrix M1(A11*(B12-B22));//递归调用全局函数,矩阵相乘 Matrix M2((A11+A12)*B22);//递归调用全局函数,矩阵相乘 Matrix M3((A21+A22)*B11);//递归调用全局函数,矩阵相乘 Matrix M4(A22*(B21-B11));//递归调用全局函数,矩阵相乘
Matrix M5((A11+A22)*(B11+B22));//递归调用全局函数,矩阵相乘 Matrix M6((A12-A22)*(B21+B22));//递归调用全局函数,矩阵相乘 Matrix M7
((A11-A21)*(B11+B12));//递归调用全局函数,矩阵相乘 Matrix C11(M5+M4-M2+M6);//调用全局函数,矩阵相加/减 Matrix C12(M1+M2);//调用全局函数,矩阵相加
Matrix C21(M3+M4);//调用全局函数,矩阵相加
Matrix C22(M5+M1-M3-M7);//调用全局函数,矩阵相加/减
Matrix T(rhs1.Side());//返回的矩阵 //设置C11-C22为T的四个小矩阵
T.SetQuarter(C11,0).SetQuarter(C12,1).SetQuarter(C21,2).SetQuarter(C22,3); return T; }
bool Is2Pow(int i)//判断i是否是2的n次方 { if(i<2) return false; while(i>2){
if(i%2) return false; i/=2; }
return i==2 ? true:false; }
int main() { int M;
cout<<"Input two matrixes[M*M],and culculate multiply with Strassen!"<<endl; cout<<"M="; cin>>M;
if(Is2Pow(M)==false){
cout<<"Error:M should equal 2^n"; return 1; }
cout<<"Input Matrix A["<<M<<"*"<<M<<"]:"<<endl;
Matrix A(M);
for(int i=0;i<M;++i) for(int j=0;j<M;++j) cin>>A(i,j);
cout<<"Input Matrix B["<<M<<"*"<<M<<"]:"<<endl;
Matrix B(M);
for(int i=0;i<M;++i) for(int j=0;j<M;++j) cin>>B(i,j); Matrix AB(A*B);
cout<<"A*B"<<endl; AB.Show();
Matrix BA(B*A);
cout<<"B*A"<<endl; BA.Show(); return 0; }
因篇幅问题不能全部显示,请点此查看更多更全内容