Skip to content

Commit 821800f

Browse files
authored
[SYCLomatic][PTX] Refine migration of PTX asm instruction "lop3.b32" (#2592)
Signed-off-by: chenwei.sun <chenwei.sun@intel.com>
1 parent 7e8a364 commit 821800f

3 files changed

Lines changed: 140 additions & 23 deletions

File tree

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -943,24 +943,27 @@ class SYCLGen : public SYCLGenBase {
943943
return SYCLGenError();
944944
OS() << " = ";
945945

946-
std::string Op[3];
947-
for (auto Idx : llvm::seq(0, 3)) {
946+
std::string Op[4];
947+
for (auto Idx : llvm::seq(0, 4)) {
948948
if (tryEmitStmt(Op[Idx], I->getInputOperand(Idx)))
949949
return SYCLGenError();
950950
}
951951

952-
if (!isa<InlineAsmIntegerLiteral>(I->getInputOperand(3)))
953-
return SYCLGenError();
954-
unsigned Imm = dyn_cast<InlineAsmIntegerLiteral>(I->getInputOperand(3))
955-
->getValue()
956-
.getZExtValue();
952+
if (!isa<InlineAsmIntegerLiteral>(I->getInputOperand(3))) {
953+
OS() << MapNames::getDpctNamespace() << "ternary_logic_op(" << Op[0]
954+
<< ", " << Op[1] << ", " << Op[2] << ", " << Op[3] << ")";
955+
956+
} else {
957+
unsigned Imm = dyn_cast<InlineAsmIntegerLiteral>(I->getInputOperand(3))
958+
->getValue()
959+
.getZExtValue();
957960

958961
#define EMPTY nullptr
959962
#define EMPTY4 EMPTY, EMPTY, EMPTY, EMPTY
960963
#define EMPTY16 EMPTY4, EMPTY4, EMPTY4, EMPTY4
961-
constexpr const char *FastMap[256] = {
962-
/*0x00*/ "0",
963-
// clang-format off
964+
constexpr const char *FastMap[256] = {
965+
/*0x00*/ "0",
966+
// clang-format off
964967
EMPTY16, EMPTY4, EMPTY4, EMPTY,
965968
/*0x1a*/ "({0} & {1} | {2}) ^ {0}",
966969
EMPTY, EMPTY, EMPTY,
@@ -988,12 +991,12 @@ class SYCLGen : public SYCLGenBase {
988991
EMPTY16, EMPTY, EMPTY, EMPTY,
989992
/*0xfe*/ "{0} | {1} | {2}",
990993
/*0xff*/ "uint32_t(-1)"};
991-
// clang-format on
994+
// clang-format on
992995

993996
#undef EMPTY16
994997
#undef EMPTY4
995998
#undef EMPTY
996-
// clang-format off
999+
// clang-format off
9971000
constexpr const char *SlowMap[8] = {
9981001
/* 0x01*/ "(~{0} & ~{1} & ~{2})",
9991002
/* 0x02*/ "(~{0} & ~{1} & {2})",
@@ -1004,20 +1007,21 @@ class SYCLGen : public SYCLGenBase {
10041007
/* 0x40*/ "({0} & {1} & ~{2})",
10051008
/* 0x80*/ "({0} & {1} & {2})"
10061009
};
1007-
// clang-format on
1010+
// clang-format on
10081011

1009-
if (FastMap[Imm]) {
1010-
OS() << llvm::formatv(FastMap[Imm], Op[0], Op[1], Op[2]);
1011-
} else {
1012-
SmallVector<std::string, 8> Templates;
1013-
for (auto Bit : llvm::seq(0, 8)) {
1014-
if (Imm & (1U << Bit)) {
1015-
Templates.push_back(
1016-
llvm::formatv(SlowMap[Bit], Op[0], Op[1], Op[2]).str());
1012+
if (FastMap[Imm]) {
1013+
OS() << llvm::formatv(FastMap[Imm], Op[0], Op[1], Op[2]);
1014+
} else {
1015+
SmallVector<std::string, 8> Templates;
1016+
for (auto Bit : llvm::seq(0, 8)) {
1017+
if (Imm & (1U << Bit)) {
1018+
Templates.push_back(
1019+
llvm::formatv(SlowMap[Bit], Op[0], Op[1], Op[2]).str());
1020+
}
10171021
}
1018-
}
10191022

1020-
OS() << llvm::join(Templates, " | ");
1023+
OS() << llvm::join(Templates, " | ");
1024+
}
10211025
}
10221026

10231027
endstmt();

clang/runtime/dpct-rt/include/dpct/util.hpp

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1205,6 +1205,106 @@ template <typename Func, std::size_t N> struct nth_argument_type {
12051205
using type = decltype(helper(std::declval<Func>()));
12061206
};
12071207

1208+
/// \brief Performs an arbitrary bitwise logical operation on the three inputs
/// \p a, \p b and \p c, selected by the 8-bit truth table \p lut, and returns
/// the result.
///
/// Every bit position is evaluated independently: for the bits (ai, bi, ci)
/// taken from \p a, \p b and \p c at that position, the output bit is bit
/// number (ai << 2 | bi << 1 | ci) of \p lut. For example, bit 0 of \p lut
/// (0x01) selects ~a & ~b & ~c and bit 7 (0x80) selects a & b & c.
///
/// \param [in] a First input value
/// \param [in] b Second input value
/// \param [in] c Third input value
/// \param [in] lut 8-bit truth table that selects the logical operation
/// \returns The bitwise result of applying the truth table to \p a, \p b and
/// \p c
inline uint32_t ternary_logic_op(uint32_t a, uint32_t b, uint32_t c,
                                 uint8_t lut) {
  uint32_t result = 0;

  // The explicit cases are shortcuts: single-minterm tables and a few tables
  // that have a compact closed form. Every other table is assembled from its
  // minterms in the default branch.
  switch (lut) {
  case 0x0:
    result = 0;
    break;
  case 0x1:
    result = ~a & ~b & ~c;
    break;
  case 0x2:
    result = ~a & ~b & c;
    // This break was missing, so 0x2 fell through and returned the 0x4 result.
    break;
  case 0x4:
    result = ~a & b & ~c;
    break;
  case 0x8:
    result = ~a & b & c;
    break;
  case 0x10:
    result = a & ~b & ~c;
    break;
  case 0x20:
    result = a & ~b & c;
    break;
  case 0x40:
    result = a & b & ~c;
    break;
  case 0x80:
    result = a & b & c;
    break;
  case 0x1a:
    result = ((a & b) | c) ^ a;
    break;
  case 0x1e:
    result = a ^ (b | c);
    break;
  case 0x2d:
    result = ~a ^ (~b & c);
    break;
  case 0x78:
    result = a ^ (b & c);
    break;
  case 0x96:
    result = a ^ b ^ c;
    break;
  case 0xb4:
    result = a ^ (b & ~c);
    break;
  case 0xb8:
    result = a ^ (b & (c ^ a));
    break;
  case 0xd2:
    result = a ^ (~b & c);
    break;
  case 0xe8:
    result = (a & (b | c)) | (b & c);
    break;
  case 0xea:
    result = (a & b) | c;
    break;
  case 0xfe:
    result = a | b | c;
    break;
  case 0xff:
    result = 0xFFFFFFFFu;
    break;
  default: {
    // Generic path: OR together the minterm of every bit set in the table.
    if (lut & 0x01)
      result |= ~a & ~b & ~c;
    if (lut & 0x02)
      result |= ~a & ~b & c;
    if (lut & 0x04)
      result |= ~a & b & ~c;
    if (lut & 0x08)
      result |= ~a & b & c;
    if (lut & 0x10)
      result |= a & ~b & ~c;
    if (lut & 0x20)
      result |= a & ~b & c;
    if (lut & 0x40)
      result |= a & b & ~c;
    if (lut & 0x80)
      result |= a & b & c;
    break;
  }
  }

  return result;
}
1307+
12081308
#ifdef _WIN32
12091309
#define DPCT_EXPORT __declspec(dllexport)
12101310
#else

clang/test/dpct/asm/lop3.cu

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,17 @@ __device__ int hard(int a) {
3535
asm("lop3.b32 %0, %1, %2, 3, 0x1C;" : "=r"(d4) : "r"(a + B), "r"(B));
3636
return d4;
3737
}
38+
39+
// CHECK: template <int lut, typename T> inline T lop3(T a, T b, T c) {
40+
// CHECK-NEXT: T res;
41+
// CHECK-NEXT: res = dpct::ternary_logic_op(a, b, c, lut);
42+
// CHECK-NEXT: return res;
43+
// CHECK-NEXT:}
44+
// Test fixture: wraps the PTX `lop3.b32` instruction in a device function
// template. The truth table `lut` is a template parameter rather than an
// integer literal written in the asm string; the CHECK lines preceding this
// definition expect the asm statement to be migrated to
// dpct::ternary_logic_op(a, b, c, lut).
template <int lut, typename T> __device__ inline T lop3(T a, T b, T c) {
45+
T res;
46+
// %0 is the output `res`; %1-%3 are the register inputs a, b, c; %4 is `lut`,
// bound with the "n" constraint.
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
47+
: "=r"(res)
48+
: "r"(a), "r"(b), "r"(c), "n"(lut));
49+
return res;
50+
}
50+
}
3851
// clang-format on

0 commit comments

Comments
 (0)