Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 45 additions & 45 deletions source/source_lcao/module_gint/gint_atom.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ GintAtom::GintAtom(
RadialBlock block;
block.begin_iw = iw;
block.size = 2 * l + 1;
// The first orbital in each radial block always starts from m = 0.
block.ylm_begin = atom_->iw2_ylm[iw];
block.psi_uniform = p_psi_uniform_[iw];
block.dpsi_uniform = p_dpsi_uniform_[iw];
Expand All @@ -48,39 +47,24 @@ template <typename T>
void GintAtom::set_phi(const std::vector<Vec3d>& coords, const int stride, T* phi) const
{
const int num_mgrids = coords.size();

// orb_ does not have the member variable dr_uniform
const double dr_uniform = orb_->PhiLN(0, 0).dr_uniform;

// store the spherical harmonics
// it's outside the loop to reduce the vector allocation overhead
std::vector<double> ylma;
const auto* blocks = radial_blocks_.data();
const int num_blocks = radial_blocks_.size();

for(int im = 0; im < num_mgrids; im++)
{
const Vec3d& coord = coords[im];
// 1e-9 is to avoid division by zero
const double dist = coord.norm() < 1e-9 ? 1e-9 : coord.norm();
if(dist > orb_->getRcut())
{
// if the distance is larger than the cutoff radius,
// the wave function values are all zeros
ModuleBase::GlobalFunc::ZEROS(phi + im * stride, atom_->nw);
}
else
{
// spherical harmonics
// TODO: vectorize the sph_harm function,
// the vectorized function can be called once for all meshgrids in a biggrid
ModuleBase::Ylm::sph_harm(atom_->nwl, coord.x/dist, coord.y/dist, coord.z/dist, ylma);
// interpolation

// these parameters are related to interpolation
// because once the distance from atom to grid point is known,
// we can obtain the parameters for interpolation and
// store them first! these operations can save lots of efforts.
const double position = dist / dr_uniform;
const int ip = static_cast<int>(position);
const double dx = position - ip;
Expand All @@ -103,9 +87,6 @@ void GintAtom::set_phi(const std::vector<Vec3d>& coords, const int stride, T* ph

const int begin_iw = block.begin_iw;
const int end_iw = begin_iw + block.size;
// Within one (L, N) block, m runs consecutively, so we can walk
// the Ylm buffer linearly instead of reading atom_->iw2_ylm[iw]
// for every orbital in the hot loop.
int idx_lm = block.ylm_begin;
for (int iw = begin_iw; iw < end_iw; ++iw, ++idx_lm)
{
Expand All @@ -121,25 +102,36 @@ void GintAtom::set_phi_dphi(
const std::vector<Vec3d>& coords, const int stride,
T* phi, T* dphi_x, T* dphi_y, T* dphi_z) const
{
if (phi != nullptr) {
phi = (T*)__builtin_assume_aligned(phi, 64);
}
if (dphi_x != nullptr) {
dphi_x = (T*)__builtin_assume_aligned(dphi_x, 64);
dphi_y = (T*)__builtin_assume_aligned(dphi_y, 64);
dphi_z = (T*)__builtin_assume_aligned(dphi_z, 64);
}
const int num_mgrids = coords.size();

// orb_ does not have the member variable dr_uniform
const double dr_uniform = orb_->PhiLN(0, 0).dr_uniform;

const int nylm = std::pow(atom_->nwl + 1, 2);
std::vector<double> rly(nylm);
std::vector<double> grly(nylm * 3);

// Flatten into a single contiguous 1-D buffer
std::vector<double> grly_data(nylm * 3);

// Build a proxy array of row pointers (2-D view) to fit the underlying interface
std::vector<double*> grly_ptrs(nylm);
for(int i = 0; i < nylm; ++i) {
grly_ptrs[i] = &grly_data[i * 3];
}

for(int im = 0; im < num_mgrids; im++)
{
const Vec3d& coord = coords[im];
// 1e-9 is to avoid division by zero
const double dist = coord.norm() < 1e-9 ? 1e-9 : coord.norm();

if(dist > orb_->getRcut())
{
// if the distance is larger than the cutoff radius,
// the wave function values are all zeros
if(phi != nullptr)
{
ModuleBase::GlobalFunc::ZEROS(phi + im * stride, atom_->nw);
Expand All @@ -150,12 +142,9 @@ void GintAtom::set_phi_dphi(
}
else
{
// spherical harmonics
// TODO: vectorize the sph_harm function,
// the vectorized function can be called once for all meshgrids in a biggrid
ModuleBase::Ylm::grad_rl_sph_harm(atom_->nwl, coord.x, coord.y, coord.z, rly.data(), grly.data());
// Pass the proxy pointer array; the underlying function writes its results into grly_data
ModuleBase::Ylm::grad_rl_sph_harm(atom_->nwl, coord.x, coord.y, coord.z, rly.data(), grly_ptrs.data());

// interpolation
const double position = dist / dr_uniform;
const int ip = static_cast<int>(position);
const double x0 = position - ip;
Expand All @@ -166,43 +155,53 @@ void GintAtom::set_phi_dphi(
const double x03 = x0 * x3 / 2;

double tmp, dtmp;

// Compute the values for each orbital iw
for(int iw = 0; iw < atom_->nw; ++iw)
{
// this is a new 'l', we need 1D orbital wave
// function from interpolation method.
if(atom_->iw2_new[iw])
{
auto psi_uniform = p_psi_uniform_[iw];
auto dpsi_uniform = p_dpsi_uniform_[iw];
// use the polynomial interpolation method to get the
// wave functions

tmp = x12 * (psi_uniform[ip] * x3 + psi_uniform[ip + 3] * x0)
+ x03 * (psi_uniform[ip + 1] * x2 - psi_uniform[ip + 2] * x1);

dtmp = x12 * (dpsi_uniform[ip] * x3 + dpsi_uniform[ip + 3] * x0)
+ x03 * (dpsi_uniform[ip + 1] * x2 - dpsi_uniform[ip + 2] * x1);
} // new l is used.
}

// get the 'l' of this localized wave function
const int ll = atom_->iw2l[iw];
const int idx_lm = atom_->iw2_ylm[iw];

const double rl = pow_int(dist, ll);
double rl = 1.0;
switch (ll) {
case 4: rl = dist * dist * dist * dist; break;
case 3: rl = dist * dist * dist; break;
case 2: rl = dist * dist; break;
case 1: rl = dist; break;
case 0: rl = 1.0; break;
default: rl = pow_int(dist, ll);
}

const double tmprl = tmp / rl;
const double tmpdphi_rly = (dtmp - tmp * ll / dist) / rl / dist;

// 3D wave functions
// Removed the erroneous inner im loop; assign directly for the current grid point (im) and current orbital (iw)
if(phi != nullptr)
{
phi[im * stride + iw] = tmprl * rly[idx_lm];
}

// derivative of wave functions with respect to atom positions.
const double tmpdphi_rly = (dtmp - tmp * ll / dist) / rl * rly[idx_lm] / dist;

dphi_x[im * stride + iw] = tmpdphi_rly * coord.x + tmprl * grly[idx_lm*3];
dphi_y[im * stride + iw] = tmpdphi_rly * coord.y + tmprl * grly[idx_lm*3 + 1];
dphi_z[im * stride + iw] = tmpdphi_rly * coord.z + tmprl * grly[idx_lm*3 + 2];
if(dphi_x != nullptr)
{
double tmpdphi_rly_val = tmpdphi_rly * rly[idx_lm];

// Address the gradient components via offsets into the 1-D array
dphi_x[im * stride + iw] = tmpdphi_rly_val * coord.x + tmprl * grly_data[idx_lm * 3 + 0];
dphi_y[im * stride + iw] = tmpdphi_rly_val * coord.y + tmprl * grly_data[idx_lm * 3 + 1];
dphi_z[im * stride + iw] = tmpdphi_rly_val * coord.z + tmprl * grly_data[idx_lm * 3 + 2];
}
}
}
}
Expand All @@ -214,4 +213,5 @@ template void GintAtom::set_phi(const std::vector<Vec3d>& coords, const int stri
template void GintAtom::set_phi(const std::vector<Vec3d>& coords, const int stride, std::complex<double>* phi) const;
template void GintAtom::set_phi_dphi(const std::vector<Vec3d>& coords, const int stride, float* phi, float* dphi_x, float* dphi_y, float* dphi_z) const;
template void GintAtom::set_phi_dphi(const std::vector<Vec3d>& coords, const int stride, double* phi, double* dphi_x, double* dphi_y, double* dphi_z) const;
}

} // namespace ModuleGint