diff --git a/README.md b/README.md index d2e8781b..0f38c1a6 100644 --- a/README.md +++ b/README.md @@ -169,7 +169,7 @@ async fn main() -> anyhow::Result<()> { } ``` -The generated tool `inputSchema` is derived from the fields of `T`. The type name and documentation on `T` are ignored; only field names, field types, and field documentation are used. +The generated tool `inputSchema` and `outputSchema` are derived from the fields of `T`. The type name and documentation on `T` are ignored; only field names, field types, and field documentation are used. When you need custom server metadata or multiple capabilities (tools + prompts), use explicit `#[tool_handler]`: diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index b4b6c0b9..5e5044eb 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -232,6 +232,13 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { // if found, use the Parameters schema syn::parse2::(quote! { rmcp::handler::server::common::schema_for_input::<#params_ty>() + .unwrap_or_else(|e| { + panic!( + "Invalid input schema for `{}`: {}", + std::any::type_name::<#params_ty>(), + e + ) + }) })? } else { // if not found, use a default empty JSON schema object diff --git a/crates/rmcp/src/handler/server/common.rs b/crates/rmcp/src/handler/server/common.rs index 153e33c6..aa996b5b 100644 --- a/crates/rmcp/src/handler/server/common.rs +++ b/crates/rmcp/src/handler/server/common.rs @@ -1,6 +1,10 @@ //! Common utilities shared between tool and prompt handlers -use std::{any::TypeId, collections::HashMap, sync::Arc}; +use std::{ + any::TypeId, + collections::HashMap, + sync::{Arc, LazyLock}, +}; use schemars::JsonSchema; @@ -30,12 +34,10 @@ pub fn schema_for_type() -> Arc { let generator = settings.into_generator(); let schema = generator.into_root_schema_for::(); let object = serde_json::to_value(schema).expect("failed to serialize schema"); - let object = match object { - serde_json::Value::Object(object) => object, - _ => panic!( - "Schema serialization produced non-object value: expected JSON object but got {:?}", - object - ), + let serde_json::Value::Object(object) = object else { + panic!( + "Schema serialization produced non-object value: expected JSON object but got {object:?}" + ); }; let schema = Arc::new(object); cache @@ -48,51 +50,63 @@ pub fn schema_for_type() -> Arc { }) } -/// Generate a JSON schema for inputSchema (does not need "title" or "description" fields for the top-level object) -pub fn schema_for_input() -> Arc { +/// Validate that the schema root is `type: "object"` (per MCP spec) and strip top-level +/// `title`/`description` (the wrapper type name and doc, which are noise to the LLM). +fn validate_and_strip(raw: &Arc, purpose: &str) -> Result, String> { + match raw.get("type") { + Some(serde_json::Value::String(t)) if t == "object" => { + let mut object = raw.as_ref().clone(); + object.remove("title"); + object.remove("description"); + Ok(Arc::new(object)) + } + Some(serde_json::Value::String(t)) => Err(format!( + "MCP specification requires tool {purpose} to have root type 'object', but found '{t}'." + )), + None => Err(format!( + "Schema is missing 'type' field. MCP specification requires {purpose} to have root type 'object'." + )), + Some(other) => Err(format!( + "Schema 'type' field has unexpected format: {other:?}. Expected \"object\"." + )), + } +} + +/// Generate, validate, and strip a JSON schema for inputSchema (must have root type "object"; +/// top-level "title" and "description" are removed). +pub fn schema_for_input() -> Result, String> { thread_local! { - static CACHE_FOR_INPUT: std::sync::RwLock>> = Default::default(); + static CACHE_FOR_INPUT: std::sync::RwLock, String>>> = Default::default(); }; CACHE_FOR_INPUT.with(|cache| { - if let Some(schema) = cache + if let Some(result) = cache .read() .expect("input schema cache lock poisoned") .get(&TypeId::of::()) { - schema.clone() - } else { - let mut schema = schema_for_type::().as_ref().clone(); - - // Remove unnecessary top-level fields - schema.remove("title"); - schema.remove("description"); - - let schema = Arc::new(schema); - cache - .write() - .expect("input schema cache lock poisoned") - .insert(TypeId::of::(), schema.clone()); - - schema + return result.clone(); } + let result = validate_and_strip(&schema_for_type::(), "inputSchema"); + cache + .write() + .expect("input schema cache lock poisoned") + .insert(TypeId::of::(), result.clone()); + result }) } -// TODO: should be updated according to the new specifications /// Schema used when input is empty. pub fn schema_for_empty_input() -> Arc { - std::sync::Arc::new( - serde_json::json!({ - "type": "object", - "properties": {} - }) - .as_object() - .unwrap() - .clone(), - ) + static EMPTY: LazyLock> = LazyLock::new(|| { + let mut object = JsonObject::new(); + object.insert("type".into(), serde_json::json!("object")); + object.insert("properties".into(), serde_json::json!({})); + Arc::new(object) + }); + EMPTY.clone() } -/// Generate and validate a JSON schema for outputSchema (must have root type "object"). +/// Generate a JSON schema for outputSchema (must have root type "object"; top-level "title" and "description" are removed) pub fn schema_for_output() -> Result, String> { thread_local! { static CACHE_FOR_OUTPUT: std::sync::RwLock, String>>> = Default::default(); @@ -108,22 +122,8 @@ pub fn schema_for_output() -> Result(); - let result = match schema.get("type") { - Some(serde_json::Value::String(t)) if t == "object" => Ok(schema.clone()), - Some(serde_json::Value::String(t)) => Err(format!( - "MCP specification requires tool outputSchema to have root type 'object', but found '{}'.", - t - )), - None => Err( - "Schema is missing 'type' field. MCP specification requires outputSchema to have root type 'object'.".to_string() - ), - Some(other) => Err(format!( - "Schema 'type' field has unexpected format: {:?}. Expected \"object\".", - other - )), - }; + // Generate, validate, and strip unnecessary top-level fields + let result = validate_and_strip(&schema_for_type::(), "outputSchema"); // Cache the result (both success and error cases) cache @@ -316,4 +316,40 @@ mod tests { let result = schema_for_output::(); assert!(result.is_ok(),); } + + #[test] + fn test_schema_for_output_strips_top_level_title() { + let schema = schema_for_output::().unwrap(); + assert!(!schema.contains_key("title")); + } + + #[test] + fn test_schema_for_output_strips_top_level_description() { + let schema = schema_for_output::().unwrap(); + assert!(!schema.contains_key("description")); + } + + #[test] + fn test_schema_for_input_rejects_primitive() { + let result = schema_for_input::(); + assert!(result.is_err()); + } + + #[test] + fn test_schema_for_input_accepts_object() { + let result = schema_for_input::(); + assert!(result.is_ok()); + } + + #[test] + fn test_schema_for_input_strips_top_level_title() { + let schema = schema_for_input::().unwrap(); + assert!(!schema.contains_key("title")); + } + + #[test] + fn test_schema_for_input_strips_top_level_description() { + let schema = schema_for_input::().unwrap(); + assert!(!schema.contains_key("description")); + } } diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs index 8e291327..c2bf299c 100644 --- a/crates/rmcp/src/handler/server/router/tool.rs +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -250,7 +250,9 @@ where attr: Tool::new( name.into(), "", - schema_for_input::(), + schema_for_input::().unwrap_or_else(|e| { + panic!("Invalid input schema for JsonObject: {e}"); + }), ), call: self, _marker: std::marker::PhantomData, @@ -287,7 +289,12 @@ where self } pub fn parameters(mut self) -> Self { - self.attr.input_schema = schema_for_input::(); + self.attr.input_schema = schema_for_input::().unwrap_or_else(|e| { + panic!( + "Invalid input schema for `{}`: {e}", + std::any::type_name::() + ) + }); self } pub fn parameters_value(mut self, schema: serde_json::Value) -> Self { diff --git a/crates/rmcp/src/handler/server/router/tool/tool_traits.rs b/crates/rmcp/src/handler/server/router/tool/tool_traits.rs index 57977db3..b0bf9e2d 100644 --- a/crates/rmcp/src/handler/server/router/tool/tool_traits.rs +++ b/crates/rmcp/src/handler/server/router/tool/tool_traits.rs @@ -49,7 +49,14 @@ pub trait ToolBase { /// If the tool does not have any parameters, you should override this methods to return [`None`], /// and when invoked, the parameter will get default values. fn input_schema() -> Option> { - Some(schema_for_input::>()) + Some( + schema_for_input::>().unwrap_or_else(|e| { + panic!( + "Invalid input schema for ToolBase::Parameter type `{0}`: {e}", + std::any::type_name::(), + ); + }), + ) } /// Json schema for tool output. diff --git a/crates/rmcp/src/model/tool.rs b/crates/rmcp/src/model/tool.rs index 11bba529..2ed89d6a 100644 --- a/crates/rmcp/src/model/tool.rs +++ b/crates/rmcp/src/model/tool.rs @@ -330,7 +330,8 @@ impl Tool { /// Set the input schema using a type that implements JsonSchema #[cfg(feature = "server")] pub fn with_input_schema(mut self) -> Self { - self.input_schema = crate::handler::server::tool::schema_for_input::(); + self.input_schema = crate::handler::server::tool::schema_for_input::() + .unwrap_or_else(|e| panic!("Invalid input schema for tool '{}': {}", self.name, e)); self } diff --git a/crates/rmcp/tests/test_list_tools_result.rs b/crates/rmcp/tests/test_list_tools_result.rs index 1736b0ee..4d8dc70a 100644 --- a/crates/rmcp/tests/test_list_tools_result.rs +++ b/crates/rmcp/tests/test_list_tools_result.rs @@ -1,6 +1,7 @@ #![cfg(all(feature = "server", feature = "macros", not(feature = "local")))] use rmcp::{ + Json, handler::server::wrapper::Parameters, model::{ListToolsResult, NumberOrString, ServerJsonRpcMessage, ServerResult}, }; @@ -14,10 +15,17 @@ struct AddRequest { b: f64, } +/// Result of adding two numbers. +#[derive(Debug, serde::Serialize, schemars::JsonSchema)] +struct AddResult { + /// The sum of the two numbers. + sum: f64, +} + /// Add two numbers. #[rmcp::tool] -fn add(Parameters(AddRequest { a, b }): Parameters) -> String { - (a + b).to_string() +fn add(Parameters(AddRequest { a, b }): Parameters) -> Json { + Json(AddResult { sum: a + b }) } #[test] @@ -27,7 +35,7 @@ fn list_tools_result_matches_expected_json() { let expected: serde_json::Value = serde_json::from_slice(&expected_json).expect("invalid expected JSON fixture"); - assert_eq!(add(Parameters(AddRequest { a: 1.0, b: 2.0 })), "3"); + assert_eq!(add(Parameters(AddRequest { a: 1.0, b: 2.0 })).0.sum, 3.0); let result = ListToolsResult::with_all_items(vec![add_tool_attr()]); let response = ServerJsonRpcMessage::response( diff --git a/crates/rmcp/tests/test_list_tools_result/list_tools_result.json b/crates/rmcp/tests/test_list_tools_result/list_tools_result.json index 1ef88230..15325e8f 100644 --- a/crates/rmcp/tests/test_list_tools_result/list_tools_result.json +++ b/crates/rmcp/tests/test_list_tools_result/list_tools_result.json @@ -23,6 +23,20 @@ "a", "b" ] + }, + "outputSchema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "sum": { + "description": "The sum of the two numbers.", + "format": "double", + "type": "number" + } + }, + "required": [ + "sum" + ] } } ] diff --git a/crates/rmcp/tests/test_structured_output.rs b/crates/rmcp/tests/test_structured_output.rs index 7bb62e65..adbdfec5 100644 --- a/crates/rmcp/tests/test_structured_output.rs +++ b/crates/rmcp/tests/test_structured_output.rs @@ -28,6 +28,16 @@ pub struct UserInfo { pub age: u32, } +#[derive(Serialize, Deserialize, JsonSchema)] +pub struct GreetingRequest { + pub name: String, +} + +#[derive(Serialize, Deserialize, JsonSchema)] +pub struct GetUserRequest { + pub user_id: String, +} + #[tool_handler(router = self.tool_router)] impl ServerHandler for TestServer {} @@ -64,14 +74,17 @@ impl TestServer { /// Tool that returns regular string output #[tool(name = "get-greeting", description = "Get a greeting")] - pub async fn get_greeting(&self, name: Parameters) -> String { - format!("Hello, {}!", name.0) + pub async fn get_greeting(&self, params: Parameters) -> String { + format!("Hello, {}!", params.0.name) } /// Tool that returns structured user info #[tool(name = "get-user", description = "Get user info")] - pub async fn get_user(&self, user_id: Parameters) -> Result, String> { - if user_id.0 == "123" { + pub async fn get_user( + &self, + params: Parameters, + ) -> Result, String> { + if params.0.user_id == "123" { Ok(Json(UserInfo { name: "Alice".to_string(), age: 30,