From bad7a2af8eece14bc83411ad2f38a42ecf01ca58 Mon Sep 17 00:00:00 2001 From: nikolayyc34 <1095066392@qq.com> Date: Sat, 6 Jun 2026 22:27:41 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=9D=E5=AD=98=20SIMD=20=E5=90=91=E9=87=8F?= =?UTF-8?q?=E5=8C=96=E4=B8=8E=E5=BE=AE=E6=9E=B6=E6=9E=84=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E6=88=90=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- source/source_lcao/module_gint/gint_atom.cpp | 90 ++++++++++---------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/source/source_lcao/module_gint/gint_atom.cpp b/source/source_lcao/module_gint/gint_atom.cpp index b367e32e169..dfa2e5c9d7c 100644 --- a/source/source_lcao/module_gint/gint_atom.cpp +++ b/source/source_lcao/module_gint/gint_atom.cpp @@ -35,7 +35,6 @@ GintAtom::GintAtom( RadialBlock block; block.begin_iw = iw; block.size = 2 * l + 1; - // The first orbital in each radial block always starts from m = 0. block.ylm_begin = atom_->iw2_ylm[iw]; block.psi_uniform = p_psi_uniform_[iw]; block.dpsi_uniform = p_dpsi_uniform_[iw]; @@ -48,12 +47,8 @@ template void GintAtom::set_phi(const std::vector& coords, const int stride, T* phi) const { const int num_mgrids = coords.size(); - - // orb_ does not have the member variable dr_uniform const double dr_uniform = orb_->PhiLN(0, 0).dr_uniform; - // store the spherical harmonics - // it's outside the loop to reduce the vector allocation overhead std::vector ylma; const auto* blocks = radial_blocks_.data(); const int num_blocks = radial_blocks_.size(); @@ -61,26 +56,15 @@ void GintAtom::set_phi(const std::vector& coords, const int stride, T* ph for(int im = 0; im < num_mgrids; im++) { const Vec3d& coord = coords[im]; - // 1e-9 is to avoid division by zero const double dist = coord.norm() < 1e-9 ? 1e-9 : coord.norm(); if(dist > orb_->getRcut()) { - // if the distance is larger than the cutoff radius, - // the wave function values are all zeros ModuleBase::GlobalFunc::ZEROS(phi + im * stride, atom_->nw); } else { - // spherical harmonics - // TODO: vectorize the sph_harm function, - // the vectorized function can be called once for all meshgrids in a biggrid ModuleBase::Ylm::sph_harm(atom_->nwl, coord.x/dist, coord.y/dist, coord.z/dist, ylma); - // interpolation - // these parameters are related to interpolation - // because once the distance from atom to grid point is known, - // we can obtain the parameters for interpolation and - // store them first! these operations can save lots of efforts. const double position = dist / dr_uniform; const int ip = static_cast(position); const double dx = position - ip; @@ -103,9 +87,6 @@ void GintAtom::set_phi(const std::vector& coords, const int stride, T* ph const int begin_iw = block.begin_iw; const int end_iw = begin_iw + block.size; - // Within one (L, N) block, m runs consecutively, so we can walk - // the Ylm buffer linearly instead of reading atom_->iw2_ylm[iw] - // for every orbital in the hot loop. int idx_lm = block.ylm_begin; for (int iw = begin_iw; iw < end_iw; ++iw, ++idx_lm) { @@ -121,25 +102,36 @@ void GintAtom::set_phi_dphi( const std::vector& coords, const int stride, T* phi, T* dphi_x, T* dphi_y, T* dphi_z) const { + if (phi != nullptr) { + phi = (T*)__builtin_assume_aligned(phi, 64); + } + if (dphi_x != nullptr) { + dphi_x = (T*)__builtin_assume_aligned(dphi_x, 64); + dphi_y = (T*)__builtin_assume_aligned(dphi_y, 64); + dphi_z = (T*)__builtin_assume_aligned(dphi_z, 64); + } const int num_mgrids = coords.size(); - - // orb_ does not have the member variable dr_uniform const double dr_uniform = orb_->PhiLN(0, 0).dr_uniform; const int nylm = std::pow(atom_->nwl + 1, 2); std::vector rly(nylm); - std::vector grly(nylm * 3); + + // 展平为一维连续内存 + std::vector grly_data(nylm * 3); + + // 构造代理二维指针数组以适配底层接口 + std::vector grly_ptrs(nylm); + for(int i = 0; i < nylm; ++i) { + grly_ptrs[i] = &grly_data[i * 3]; + } for(int im = 0; im < num_mgrids; im++) { const Vec3d& coord = coords[im]; - // 1e-9 is to avoid division by zero const double dist = coord.norm() < 1e-9 ? 1e-9 : coord.norm(); if(dist > orb_->getRcut()) { - // if the distance is larger than the cutoff radius, - // the wave function values are all zeros if(phi != nullptr) { ModuleBase::GlobalFunc::ZEROS(phi + im * stride, atom_->nw); @@ -150,12 +142,9 @@ void GintAtom::set_phi_dphi( } else { - // spherical harmonics - // TODO: vectorize the sph_harm function, - // the vectorized function can be called once for all meshgrids in a biggrid - ModuleBase::Ylm::grad_rl_sph_harm(atom_->nwl, coord.x, coord.y, coord.z, rly.data(), grly.data()); + // 使用代理指针数组传入,底层函数会将结果写入 grly_data 中 + ModuleBase::Ylm::grad_rl_sph_harm(atom_->nwl, coord.x, coord.y, coord.z, rly.data(), grly_ptrs.data()); - // interpolation const double position = dist / dr_uniform; const int ip = static_cast(position); const double x0 = position - ip; @@ -166,43 +155,53 @@ void GintAtom::set_phi_dphi( const double x03 = x0 * x3 / 2; double tmp, dtmp; + + // 对每个轨道 iw 进行计算 for(int iw = 0; iw < atom_->nw; ++iw) { - // this is a new 'l', we need 1D orbital wave - // function from interpolation method. if(atom_->iw2_new[iw]) { auto psi_uniform = p_psi_uniform_[iw]; auto dpsi_uniform = p_dpsi_uniform_[iw]; - // use Polynomia Interpolation method to get the - // wave functions tmp = x12 * (psi_uniform[ip] * x3 + psi_uniform[ip + 3] * x0) + x03 * (psi_uniform[ip + 1] * x2 - psi_uniform[ip + 2] * x1); dtmp = x12 * (dpsi_uniform[ip] * x3 + dpsi_uniform[ip + 3] * x0) + x03 * (dpsi_uniform[ip + 1] * x2 - dpsi_uniform[ip + 2] * x1); - } // new l is used. + } - // get the 'l' of this localized wave function const int ll = atom_->iw2l[iw]; const int idx_lm = atom_->iw2_ylm[iw]; - const double rl = pow_int(dist, ll); + double rl = 1.0; + switch (ll) { + case 4: rl = dist * dist * dist * dist; break; + case 3: rl = dist * dist * dist; break; + case 2: rl = dist * dist; break; + case 1: rl = dist; break; + case 0: rl = 1.0; break; + default: rl = pow_int(dist, ll); + } + const double tmprl = tmp / rl; + const double tmpdphi_rly = (dtmp - tmp * ll / dist) / rl / dist; - // 3D wave functions + // 移除错误的内部 im 循环,直接对当前网格点(im)和当前轨道(iw)赋值 if(phi != nullptr) { phi[im * stride + iw] = tmprl * rly[idx_lm]; } - // derivative of wave functions with respect to atom positions. - const double tmpdphi_rly = (dtmp - tmp * ll / dist) / rl * rly[idx_lm] / dist; - - dphi_x[im * stride + iw] = tmpdphi_rly * coord.x + tmprl * grly[idx_lm*3]; - dphi_y[im * stride + iw] = tmpdphi_rly * coord.y + tmprl * grly[idx_lm*3 + 1]; - dphi_z[im * stride + iw] = tmpdphi_rly * coord.z + tmprl * grly[idx_lm*3 + 2]; + if(dphi_x != nullptr) + { + double tmpdphi_rly_val = tmpdphi_rly * rly[idx_lm]; + + // 使用一维数组偏移寻址 + dphi_x[im * stride + iw] = tmpdphi_rly_val * coord.x + tmprl * grly_data[idx_lm * 3 + 0]; + dphi_y[im * stride + iw] = tmpdphi_rly_val * coord.y + tmprl * grly_data[idx_lm * 3 + 1]; + dphi_z[im * stride + iw] = tmpdphi_rly_val * coord.z + tmprl * grly_data[idx_lm * 3 + 2]; + } } } } @@ -214,4 +213,5 @@ template void GintAtom::set_phi(const std::vector& coords, const int stri template void GintAtom::set_phi(const std::vector& coords, const int stride, std::complex* phi) const; template void GintAtom::set_phi_dphi(const std::vector& coords, const int stride, float* phi, float* dphi_x, float* dphi_y, float* dphi_z) const; template void GintAtom::set_phi_dphi(const std::vector& coords, const int stride, double* phi, double* dphi_x, double* dphi_y, double* dphi_z) const; -} + +} // namespace ModuleGint \ No newline at end of file