Skip to content

Commit 7bc59b6

Browse files
authored
[SYCLomatic] Support bgradb epilogue for dpct::experimental::matmul (#2607)
Signed-off-by: Jiang, Zhiwei <zhiwei.jiang@intel.com>
1 parent a92bfbe commit 7bc59b6

1 file changed

Lines changed: 42 additions & 5 deletions

File tree

clang/runtime/dpct-rt/include/dpct/blas_gemm_utils.hpp

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ class matmul_desc_t {
258258
/// scale_type==float && a_type==float && b_type==float && c_type==float.
259259
/// Currently, this function only supports beta==0 or beta==1.
260260
/// Currently, this function only supports the relu, bias, gelu, gelu_bias,
261-
/// gelu_aux, gelu_aux_bias and dgelu epilogue.
261+
/// gelu_aux, gelu_aux_bias, dgelu and bgradb epilogue.
262262
/// NOTE: Non-col-major matrix will be converted to col-major matrix before.
263263
/// TODO: Impl row-major matmul without layout conversion.
264264
/// multiplication and converted back after multiplication.
@@ -331,10 +331,12 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
331331
compute_desc->_epilogue != epilogue_t::gelu_bias &&
332332
compute_desc->_epilogue != epilogue_t::gelu_aux &&
333333
compute_desc->_epilogue != epilogue_t::gelu_aux_bias &&
334-
compute_desc->_epilogue != epilogue_t::dgelu) {
335-
throw std::runtime_error("dpct::blas_gemm::experimental::matmul() only "
336-
"supports relu, bias, gelu, gelu_bias, gelu_aux, "
337-
"gelu_aux_bias and dgelu epilogue currently.");
334+
compute_desc->_epilogue != epilogue_t::dgelu &&
335+
compute_desc->_epilogue != epilogue_t::bgradb) {
336+
throw std::runtime_error(
337+
"dpct::blas_gemm::experimental::matmul() only "
338+
"supports relu, bias, gelu, gelu_bias, gelu_aux, "
339+
"gelu_aux_bias, dgelu and bgradb epilogue currently.");
338340
}
339341

340342
if (!(compute_desc->_scale_type == library_data_t::real_int32 &&
@@ -559,6 +561,28 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
559561
#endif
560562
}
561563

564+
::dnnl::memory *po_bias_bgradb_mem = nullptr;
565+
auto po_bias_bgradb_md = ::dnnl::memory::desc(
566+
compute_desc->_trans_b == oneapi::mkl::transpose::nontrans
567+
? ::dnnl::memory::dims{N, 1}
568+
: ::dnnl::memory::dims{1, N},
569+
dpct::dnnl::memory_desc_ext::to_dnnl_data_type(
570+
compute_desc->_bias_data_type),
571+
compute_desc->_trans_b == oneapi::mkl::transpose::nontrans
572+
? ::dnnl::memory::dims{1, N}
573+
: ::dnnl::memory::dims{N, 1});
574+
if (compute_desc->_epilogue == epilogue_t::bgradb) {
575+
po_bias_bgradb_mem = new ::dnnl::memory(
576+
po_bias_bgradb_md, handle->get_engine(), DNNL_MEMORY_NONE);
577+
#ifdef DPCT_USM_LEVEL_NONE
578+
detail::type_dispatch<detail::set_buffer_impl>(
579+
compute_desc->_bias_data_type, po_bias_bgradb_mem,
580+
compute_desc->_bias_pointer);
581+
#else
582+
po_bias_bgradb_mem->set_data_handle(compute_desc->_bias_pointer);
583+
#endif
584+
}
585+
562586
::dnnl::memory *po_aux_mem = nullptr;
563587
auto po_aux_md = ::dnnl::memory::desc(
564588
::dnnl::memory::dims{M, N},
@@ -660,6 +684,17 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
660684
post_op_prim_event =
661685
::dnnl::sycl_interop::execute(dgelu_prim, handle->get_engine_stream(),
662686
dgelu_args, {matmul_prim_event});
687+
} else if (compute_desc->_epilogue == epilogue_t::bgradb) {
688+
auto reduction_pd = ::dnnl::reduction::primitive_desc(
689+
handle->get_engine(), ::dnnl::algorithm::reduction_sum, weights_md,
690+
po_bias_bgradb_md, 0.f, 0.f);
691+
auto reduction_prim = ::dnnl::reduction(reduction_pd);
692+
std::unordered_map<int, ::dnnl::memory> reduction_args;
693+
reduction_args.insert({DNNL_ARG_SRC, *weights_mem});
694+
reduction_args.insert({DNNL_ARG_DST, *po_bias_bgradb_mem});
695+
post_op_prim_event = ::dnnl::sycl_interop::execute(
696+
reduction_prim, handle->get_engine_stream(), reduction_args,
697+
{matmul_prim_event});
663698
}
664699

665700
// end of calling oneDNN
@@ -700,6 +735,8 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
700735
delete dst_mem;
701736
if (po_bias_mem)
702737
delete po_bias_mem;
738+
if (po_bias_bgradb_mem)
739+
delete po_bias_bgradb_mem;
703740
if (po_aux_mem)
704741
delete po_aux_mem;
705742
::dpct::cs::free((void *)new_a, *q_ptr);

0 commit comments

Comments
 (0)