@@ -258,7 +258,7 @@ class matmul_desc_t {
 /// scale_type==float && a_type==float && b_type==float && c_type==float.
 /// Currently, this function only supports beta==0 or beta==1.
 /// Currently, this function only supports the relu, bias, gelu, gelu_bias,
-/// gelu_aux, gelu_aux_bias and dgelu epilogue.
+/// gelu_aux, gelu_aux_bias, dgelu and bgradb epilogue.
 /// NOTE: Non-col-major matrix will be converted to col-major matrix before.
 /// TODO: Impl row-major matmul without layout conversion.
 /// multiplication and converted back after multiplication.
@@ -331,10 +331,12 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
       compute_desc->_epilogue != epilogue_t::gelu_bias &&
       compute_desc->_epilogue != epilogue_t::gelu_aux &&
       compute_desc->_epilogue != epilogue_t::gelu_aux_bias &&
-      compute_desc->_epilogue != epilogue_t::dgelu) {
-    throw std::runtime_error("dpct::blas_gemm::experimental::matmul() only "
-                             "supports relu, bias, gelu, gelu_bias, gelu_aux, "
-                             "gelu_aux_bias and dgelu epilogue currently.");
+      compute_desc->_epilogue != epilogue_t::dgelu &&
+      compute_desc->_epilogue != epilogue_t::bgradb) {
+    throw std::runtime_error(
+        "dpct::blas_gemm::experimental::matmul() only "
+        "supports relu, bias, gelu, gelu_bias, gelu_aux, "
+        "gelu_aux_bias, dgelu and bgradb epilogue currently.");
   }
 
   if (!(compute_desc->_scale_type == library_data_t::real_int32 &&
@@ -559,6 +561,28 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
 #endif
   }
 
+  ::dnnl::memory *po_bias_bgradb_mem = nullptr;
+  auto po_bias_bgradb_md = ::dnnl::memory::desc(
+      compute_desc->_trans_b == oneapi::mkl::transpose::nontrans
+          ? ::dnnl::memory::dims{N, 1}
+          : ::dnnl::memory::dims{1, N},
+      dpct::dnnl::memory_desc_ext::to_dnnl_data_type(
+          compute_desc->_bias_data_type),
+      compute_desc->_trans_b == oneapi::mkl::transpose::nontrans
+          ? ::dnnl::memory::dims{1, N}
+          : ::dnnl::memory::dims{N, 1});
+  if (compute_desc->_epilogue == epilogue_t::bgradb) {
+    po_bias_bgradb_mem = new ::dnnl::memory(
+        po_bias_bgradb_md, handle->get_engine(), DNNL_MEMORY_NONE);
+#ifdef DPCT_USM_LEVEL_NONE
+    detail::type_dispatch<detail::set_buffer_impl>(
+        compute_desc->_bias_data_type, po_bias_bgradb_mem,
+        compute_desc->_bias_pointer);
+#else
+    po_bias_bgradb_mem->set_data_handle(compute_desc->_bias_pointer);
+#endif
+  }
+
   ::dnnl::memory *po_aux_mem = nullptr;
   auto po_aux_md = ::dnnl::memory::desc(
       ::dnnl::memory::dims{M, N},
@@ -660,6 +684,17 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
     post_op_prim_event =
         ::dnnl::sycl_interop::execute(dgelu_prim, handle->get_engine_stream(),
                                       dgelu_args, {matmul_prim_event});
+  } else if (compute_desc->_epilogue == epilogue_t::bgradb) {
+    auto reduction_pd = ::dnnl::reduction::primitive_desc(
+        handle->get_engine(), ::dnnl::algorithm::reduction_sum, weights_md,
+        po_bias_bgradb_md, 0.f, 0.f);
+    auto reduction_prim = ::dnnl::reduction(reduction_pd);
+    std::unordered_map<int, ::dnnl::memory> reduction_args;
+    reduction_args.insert({DNNL_ARG_SRC, *weights_mem});
+    reduction_args.insert({DNNL_ARG_DST, *po_bias_bgradb_mem});
+    post_op_prim_event = ::dnnl::sycl_interop::execute(
+        reduction_prim, handle->get_engine_stream(), reduction_args,
+        {matmul_prim_event});
   }
 
   // end of calling oneDNN
@@ -700,6 +735,8 @@ inline sycl::event matmul(descriptor_ptr handle, matmul_desc_ptr compute_desc,
   delete dst_mem;
   if (po_bias_mem)
     delete po_bias_mem;
+  if (po_bias_bgradb_mem)
+    delete po_bias_bgradb_mem;
   if (po_aux_mem)
     delete po_aux_mem;
   ::dpct::cs::free((void *)new_a, *q_ptr);