@@ -634,7 +634,6 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
634634 json error = nullptr ;
635635
636636 server_task_result_ptr result = ctx_server->queue_results .recv (id_task);
637- ctx_server->queue_results .remove_waiting_task_id (id_task);
638637
639638 json response_str = result->to_json ();
640639 if (result->is_error ()) {
@@ -644,6 +643,10 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
644643 return nullptr ;
645644 }
646645
646+ if (result->is_stop ()) {
647+ ctx_server->queue_results .remove_waiting_task_id (id_task);
648+ }
649+
647650 const auto out_res = result->to_json ();
648651
649652 // Extract "embedding" as a vector of vectors (2D array)
@@ -679,6 +682,102 @@ JNIEXPORT jfloatArray JNICALL Java_de_kherud_llama_LlamaModel_embed(JNIEnv *env,
679682 return j_embedding;
680683}
681684
685+ JNIEXPORT jobject JNICALL Java_de_kherud_llama_LlamaModel_rerank (JNIEnv *env, jobject obj, jstring jprompt,
686+ jobjectArray documents) {
687+ jlong server_handle = env->GetLongField (obj, f_model_pointer);
688+ auto *ctx_server = reinterpret_cast <server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
689+
690+ if (!ctx_server->params_base .reranking || ctx_server->params_base .embedding ) {
691+ env->ThrowNew (c_llama_error,
692+ " This server does not support reranking. Start it with `--reranking` and without `--embedding`" );
693+ return nullptr ;
694+ }
695+
696+ const std::string prompt = parse_jstring (env, jprompt);
697+
698+ const auto tokenized_query = tokenize_mixed (ctx_server->vocab , prompt, true , true );
699+
700+ json responses = json::array ();
701+
702+ std::vector<server_task> tasks;
703+ const jsize amount_documents = env->GetArrayLength (documents);
704+ auto *document_array = parse_string_array (env, documents, amount_documents);
705+ auto document_vector = std::vector<std::string>(document_array, document_array + amount_documents);
706+ free_string_array (document_array, amount_documents);
707+
708+ std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts (ctx_server->vocab , document_vector, true , true );
709+
710+ tasks.reserve (tokenized_docs.size ());
711+ for (int i = 0 ; i < tokenized_docs.size (); i++) {
712+ auto task = server_task (SERVER_TASK_TYPE_RERANK);
713+ task.id = ctx_server->queue_tasks .get_new_id ();
714+ task.index = i;
715+ task.prompt_tokens = format_rerank (ctx_server->vocab , tokenized_query, tokenized_docs[i]);
716+ tasks.push_back (task);
717+ }
718+ ctx_server->queue_results .add_waiting_tasks (tasks);
719+ ctx_server->queue_tasks .post (tasks);
720+
721+ // get the result
722+ std::unordered_set<int > task_ids = server_task::get_list_id (tasks);
723+ std::vector<server_task_result_ptr> results (task_ids.size ());
724+
725+ // Create a new HashMap instance
726+ jobject o_probabilities = env->NewObject (c_hash_map, cc_hash_map);
727+ if (o_probabilities == nullptr ) {
728+ env->ThrowNew (c_llama_error, " Failed to create HashMap object." );
729+ return nullptr ;
730+ }
731+
732+ for (int i = 0 ; i < (int )task_ids.size (); i++) {
733+ server_task_result_ptr result = ctx_server->queue_results .recv (task_ids);
734+ if (result->is_error ()) {
735+ auto response = result->to_json ()[" message" ].get <std::string>();
736+ for (const int id_task : task_ids) {
737+ ctx_server->queue_results .remove_waiting_task_id (id_task);
738+ }
739+ env->ThrowNew (c_llama_error, response.c_str ());
740+ return nullptr ;
741+ }
742+
743+ const auto out_res = result->to_json ();
744+
745+ if (result->is_stop ()) {
746+ for (const int id_task : task_ids) {
747+ ctx_server->queue_results .remove_waiting_task_id (id_task);
748+ }
749+ }
750+
751+ int index = out_res[" index" ].get <int >();
752+ float score = out_res[" score" ].get <float >();
753+ std::string tok_str = document_vector[index];
754+ jstring jtok_str = env->NewStringUTF (tok_str.c_str ());
755+
756+ jobject jprob = env->NewObject (c_float, cc_float, score);
757+ env->CallObjectMethod (o_probabilities, m_map_put, jtok_str, jprob);
758+ env->DeleteLocalRef (jtok_str);
759+ env->DeleteLocalRef (jprob);
760+ }
761+ jbyteArray jbytes = parse_jbytes (env, prompt);
762+ return env->NewObject (c_output, cc_output, jbytes, o_probabilities, true );
763+ }
764+
765+ JNIEXPORT jstring JNICALL Java_de_kherud_llama_LlamaModel_applyTemplate (JNIEnv *env, jobject obj, jstring jparams) {
766+ jlong server_handle = env->GetLongField (obj, f_model_pointer);
767+ auto *ctx_server = reinterpret_cast <server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
768+
769+ std::string c_params = parse_jstring (env, jparams);
770+ json data = json::parse (c_params);
771+
772+ json templateData =
773+ oaicompat_completion_params_parse (data, ctx_server->params_base .use_jinja ,
774+ ctx_server->params_base .reasoning_format , ctx_server->chat_templates .get ());
775+ std::string tok_str = templateData.at (" prompt" );
776+ jstring jtok_str = env->NewStringUTF (tok_str.c_str ());
777+
778+ return jtok_str;
779+ }
780+
682781JNIEXPORT jintArray JNICALL Java_de_kherud_llama_LlamaModel_encode (JNIEnv *env, jobject obj, jstring jprompt) {
683782 jlong server_handle = env->GetLongField (obj, f_model_pointer);
684783 auto *ctx_server = reinterpret_cast <server_context *>(server_handle); // NOLINT(*-no-int-to-ptr)
0 commit comments