@@ -100,7 +100,7 @@ void CubTypeRule::registerMatcher(ast_matchers::MatchFinder &MF) {
100100 " cub::ArgIndexInputIterator" , " cub::DiscardOutputIterator" ,
101101 " cub::DoubleBuffer" , " cub::NullType" , " cub::ArgMax" , " cub::ArgMin" ,
102102 " cub::BlockRadixSort" , " cub::BlockExchange" , " cub::BlockLoad" ,
103- " cub::BlockStore" );
103+ " cub::BlockStore" , " cub::BlockShuffle " );
104104 };
105105
106106 MF.addMatcher (
@@ -158,15 +158,16 @@ void CubDeviceLevelRule::runRule(
158158void CubMemberCallRule::registerMatcher (ast_matchers::MatchFinder &MF) {
159159 MF.addMatcher (
160160 cxxMemberCallExpr (
161- allOf (on (hasType (hasCanonicalType (qualType (hasDeclaration (namedDecl (
162- hasAnyName (" cub::ArgIndexInputIterator" ,
163- " cub::BlockRadixSort" , " cub::BlockExchange" ,
164- " cub::BlockLoad" , " cub::BlockStore" ))))))),
161+ allOf (on (hasType (hasCanonicalType (
162+ qualType (hasDeclaration (namedDecl (hasAnyName (
163+ " cub::ArgIndexInputIterator" , " cub::BlockRadixSort" ,
164+ " cub::BlockExchange" , " cub::BlockLoad" ,
165+ " cub::BlockStore" , " cub::BlockShuffle" ))))))),
165166 callee (cxxMethodDecl (hasAnyName (
166167 " normalize" , " Sort" , " SortDescending" , " BlockedToStriped" ,
167168 " StripedToBlocked" , " ScatterToBlocked" , " ScatterToStriped" ,
168169 " SortBlockedToStriped" , " SortDescendingBlockedToStriped" ,
169- " Load" , " Store" )))))
170+ " Load" , " Store" , " Offset " , " Rotate " , " Up " , " Down " )))))
170171 .bind (" memberCall" ),
171172 this );
172173
@@ -253,13 +254,17 @@ void CubMemberCallRule::runRule(
253254 Name == " BlockedToStriped" || Name == " StripedToBlocked" ||
254255 Name == " StripedToBlocked" || Name == " ScatterToBlocked" ||
255256 Name == " ScatterToStriped" ;
256- if (isBlockRadixSort || isBlockExchange || Name == " Load" ||
257- Name == " Store" ) {
257+ bool isBlockShuffle =
258+ Name == " Offset" || Name == " Rotate" || Name == " Up" || Name == " Down" ;
259+ if (isBlockRadixSort || isBlockExchange || isBlockShuffle ||
260+ Name == " Load" || Name == " Store" ) {
258261 std::string HelpFuncName;
259262 if (isBlockRadixSort)
260263 HelpFuncName = " group_radix_sort" ;
261264 else if (isBlockExchange)
262265 HelpFuncName = " exchange" ;
266+ else if (isBlockShuffle)
267+ HelpFuncName = " group_shuffle" ;
263268 else if (Name == " Load" )
264269 HelpFuncName = " group_load" ;
265270 else if (Name == " Store" )
@@ -273,20 +278,36 @@ void CubMemberCallRule::runRule(
273278 auto *ClassSpecDecl = dyn_cast<ClassTemplateSpecializationDecl>(
274279 CanTy->getAs <RecordType>()->getDecl ());
275280 const auto &ValueTyArg = ClassSpecDecl->getTemplateArgs ()[0 ];
276- const auto &ItemsPreThreadArg = ClassSpecDecl-> getTemplateArgs ()[ 2 ];
281+
277282 ValueTyArg.getAsType ().getAsString ();
278283 std::string Fn;
279284 llvm::raw_string_ostream OS (Fn);
280285 OS << MapNames::getDpctNamespace () << " group::" << HelpFuncName << " <"
281- << ValueTyArg.getAsType ().getAsString () << " , "
282- << ItemsPreThreadArg.getAsIntegral () << " >::get_local_memory_size" ;
286+ << ValueTyArg.getAsType ().getAsString ();
287+ if (isBlockShuffle) {
288+ if (!ClassSpecDecl->getTemplateArgs ()[1 ].getIsDefaulted ()) {
289+ OS << " , " << ClassSpecDecl->getTemplateArgs ()[1 ].getAsIntegral ();
290+ }
291+ if (!ClassSpecDecl->getTemplateArgs ()[2 ].getIsDefaulted ()) {
292+ OS << " , " << ClassSpecDecl->getTemplateArgs ()[2 ].getAsIntegral ();
293+ }
294+ if (!ClassSpecDecl->getTemplateArgs ()[3 ].getIsDefaulted ()) {
295+ OS << " , " << ClassSpecDecl->getTemplateArgs ()[3 ].getAsIntegral ();
296+ }
297+ } else {
298+ const auto &ItemsPreThreadArg = ClassSpecDecl->getTemplateArgs ()[2 ];
299+ OS << " , " << ItemsPreThreadArg.getAsIntegral ();
300+ }
301+ OS << " >::get_local_memory_size" ;
283302 if (auto FuncInfo = DeviceFunctionDecl::LinkRedecls (FD)) {
284303 auto LocInfo = DpctGlobalInfo::getLocInfo (TempStorage);
285304 ExprAnalysis EA;
286305 EA.analyze (DataTypeLoc);
287306 FuncInfo->getVarMap ().addCUBTempStorage (
288307 std::make_shared<TempStorageVarInfo>(
289- LocInfo.second , TempStorageVarInfo::BlockRadixSort,
308+ LocInfo.second ,
309+ isBlockShuffle ? TempStorageVarInfo::BlockShuffle
310+ : TempStorageVarInfo::BlockRadixSort,
290311 TempStorage->getName (), Fn,
291312 EA.getTemplateDependentStringInfo ()));
292313 }
0 commit comments