From 7018860386683738c02562fd540b56de218d2443 Mon Sep 17 00:00:00 2001 From: Dale Seo <5466341+DaleSeo@users.noreply.github.com> Date: Thu, 28 May 2026 16:26:39 -0400 Subject: [PATCH 1/2] fix: remove unnecessary fields from tools' outputSchema --- README.md | 2 +- crates/rmcp/src/handler/server/common.rs | 82 +++++++++++-------- crates/rmcp/tests/test_list_tools_result.rs | 14 +++- .../list_tools_result.json | 14 ++++ 4 files changed, 74 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index d2e8781b..0f38c1a6 100644 --- a/README.md +++ b/README.md @@ -169,7 +169,7 @@ async fn main() -> anyhow::Result<()> { } ``` -The generated tool `inputSchema` is derived from the fields of `T`. The type name and documentation on `T` are ignored; only field names, field types, and field documentation are used. +The generated tool `inputSchema` and `outputSchema` are derived from the fields of `T`. The type name and documentation on `T` are ignored; only field names, field types, and field documentation are used. When you need custom server metadata or multiple capabilities (tools + prompts), use explicit `#[tool_handler]`: diff --git a/crates/rmcp/src/handler/server/common.rs b/crates/rmcp/src/handler/server/common.rs index 153e33c6..69bbc803 100644 --- a/crates/rmcp/src/handler/server/common.rs +++ b/crates/rmcp/src/handler/server/common.rs @@ -1,6 +1,10 @@ //! Common utilities shared between tool and prompt handlers -use std::{any::TypeId, collections::HashMap, sync::Arc}; +use std::{ + any::TypeId, + collections::HashMap, + sync::{Arc, LazyLock}, +}; use schemars::JsonSchema; @@ -30,12 +34,10 @@ pub fn schema_for_type() -> Arc { let generator = settings.into_generator(); let schema = generator.into_root_schema_for::(); let object = serde_json::to_value(schema).expect("failed to serialize schema"); - let object = match object { - serde_json::Value::Object(object) => object, - _ => panic!( - "Schema serialization produced non-object value: expected JSON object but got {:?}", - object - ), + let serde_json::Value::Object(object) = object else { + panic!( + "Schema serialization produced non-object value: expected JSON object but got {object:?}" + ); }; let schema = Arc::new(object); cache @@ -48,7 +50,16 @@ pub fn schema_for_type() -> Arc { }) } -/// Generate a JSON schema for inputSchema (does not need "title" or "description" fields for the top-level object) +/// Clone the schema and remove top-level "title" and "description" fields +/// (the wrapper type name and doc, which are noise to the LLM). +fn strip_top_level_metadata(schema: &Arc) -> Arc { + let mut object = schema.as_ref().clone(); + object.remove("title"); + object.remove("description"); + Arc::new(object) +} + +/// Generate a JSON schema for inputSchema (top-level "title" and "description" are removed) pub fn schema_for_input() -> Arc { thread_local! { static CACHE_FOR_INPUT: std::sync::RwLock>> = Default::default(); @@ -61,13 +72,7 @@ pub fn schema_for_input() -> Arc { { schema.clone() } else { - let mut schema = schema_for_type::().as_ref().clone(); - - // Remove unnecessary top-level fields - schema.remove("title"); - schema.remove("description"); - - let schema = Arc::new(schema); + let schema = strip_top_level_metadata(&schema_for_type::()); cache .write() .expect("input schema cache lock poisoned") @@ -78,21 +83,18 @@ pub fn schema_for_input() -> Arc { }) } -// TODO: should be updated according to the new specifications /// Schema used when input is empty. pub fn schema_for_empty_input() -> Arc { - std::sync::Arc::new( - serde_json::json!({ - "type": "object", - "properties": {} - }) - .as_object() - .unwrap() - .clone(), - ) + static EMPTY: LazyLock> = LazyLock::new(|| { + let mut object = JsonObject::new(); + object.insert("type".into(), serde_json::json!("object")); + object.insert("properties".into(), serde_json::json!({})); + Arc::new(object) + }); + EMPTY.clone() } -/// Generate and validate a JSON schema for outputSchema (must have root type "object"). +/// Generate a JSON schema for outputSchema (must have root type "object"; top-level "title" and "description" are removed) pub fn schema_for_output() -> Result, String> { thread_local! { static CACHE_FOR_OUTPUT: std::sync::RwLock, String>>> = Default::default(); @@ -108,20 +110,20 @@ pub fn schema_for_output() -> Result(); - let result = match schema.get("type") { - Some(serde_json::Value::String(t)) if t == "object" => Ok(schema.clone()), + // Generate, validate, and strip unnecessary top-level fields + let raw = schema_for_type::(); + let result = match raw.get("type") { + Some(serde_json::Value::String(t)) if t == "object" => { + Ok(strip_top_level_metadata(&raw)) + } Some(serde_json::Value::String(t)) => Err(format!( - "MCP specification requires tool outputSchema to have root type 'object', but found '{}'.", - t + "MCP specification requires tool outputSchema to have root type 'object', but found '{t}'." )), None => Err( "Schema is missing 'type' field. MCP specification requires outputSchema to have root type 'object'.".to_string() ), Some(other) => Err(format!( - "Schema 'type' field has unexpected format: {:?}. Expected \"object\".", - other + "Schema 'type' field has unexpected format: {other:?}. Expected \"object\"." )), }; @@ -316,4 +318,16 @@ mod tests { let result = schema_for_output::(); assert!(result.is_ok(),); } + + #[test] + fn test_schema_for_output_strips_top_level_title() { + let schema = schema_for_output::().unwrap(); + assert!(!schema.contains_key("title")); + } + + #[test] + fn test_schema_for_output_strips_top_level_description() { + let schema = schema_for_output::().unwrap(); + assert!(!schema.contains_key("description")); + } } diff --git a/crates/rmcp/tests/test_list_tools_result.rs b/crates/rmcp/tests/test_list_tools_result.rs index 1736b0ee..4d8dc70a 100644 --- a/crates/rmcp/tests/test_list_tools_result.rs +++ b/crates/rmcp/tests/test_list_tools_result.rs @@ -1,6 +1,7 @@ #![cfg(all(feature = "server", feature = "macros", not(feature = "local")))] use rmcp::{ + Json, handler::server::wrapper::Parameters, model::{ListToolsResult, NumberOrString, ServerJsonRpcMessage, ServerResult}, }; @@ -14,10 +15,17 @@ struct AddRequest { b: f64, } +/// Result of adding two numbers. +#[derive(Debug, serde::Serialize, schemars::JsonSchema)] +struct AddResult { + /// The sum of the two numbers. + sum: f64, +} + /// Add two numbers. #[rmcp::tool] -fn add(Parameters(AddRequest { a, b }): Parameters) -> String { - (a + b).to_string() +fn add(Parameters(AddRequest { a, b }): Parameters) -> Json { + Json(AddResult { sum: a + b }) } #[test] @@ -27,7 +35,7 @@ fn list_tools_result_matches_expected_json() { let expected: serde_json::Value = serde_json::from_slice(&expected_json).expect("invalid expected JSON fixture"); - assert_eq!(add(Parameters(AddRequest { a: 1.0, b: 2.0 })), "3"); + assert_eq!(add(Parameters(AddRequest { a: 1.0, b: 2.0 })).0.sum, 3.0); let result = ListToolsResult::with_all_items(vec![add_tool_attr()]); let response = ServerJsonRpcMessage::response( diff --git a/crates/rmcp/tests/test_list_tools_result/list_tools_result.json b/crates/rmcp/tests/test_list_tools_result/list_tools_result.json index 1ef88230..15325e8f 100644 --- a/crates/rmcp/tests/test_list_tools_result/list_tools_result.json +++ b/crates/rmcp/tests/test_list_tools_result/list_tools_result.json @@ -23,6 +23,20 @@ "a", "b" ] + }, + "outputSchema": { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "type": "object", + "properties": { + "sum": { + "description": "The sum of the two numbers.", + "format": "double", + "type": "number" + } + }, + "required": [ + "sum" + ] } } ] From d8f81f4a1c297d26042eb73d005277ac8feeb37d Mon Sep 17 00:00:00 2001 From: Dale Seo <5466341+DaleSeo@users.noreply.github.com> Date: Thu, 28 May 2026 17:09:32 -0400 Subject: [PATCH 2/2] fix: validate input schema root type per MCP spec --- crates/rmcp-macros/src/tool.rs | 7 ++ crates/rmcp/src/handler/server/common.rs | 92 ++++++++++++------- crates/rmcp/src/handler/server/router/tool.rs | 11 ++- .../handler/server/router/tool/tool_traits.rs | 9 +- crates/rmcp/src/model/tool.rs | 3 +- crates/rmcp/tests/test_structured_output.rs | 21 ++++- 6 files changed, 100 insertions(+), 43 deletions(-) diff --git a/crates/rmcp-macros/src/tool.rs b/crates/rmcp-macros/src/tool.rs index b4b6c0b9..5e5044eb 100644 --- a/crates/rmcp-macros/src/tool.rs +++ b/crates/rmcp-macros/src/tool.rs @@ -232,6 +232,13 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result { // if found, use the Parameters schema syn::parse2::(quote! { rmcp::handler::server::common::schema_for_input::<#params_ty>() + .unwrap_or_else(|e| { + panic!( + "Invalid input schema for `{}`: {}", + std::any::type_name::<#params_ty>(), + e + ) + }) })? } else { // if not found, use a default empty JSON schema object diff --git a/crates/rmcp/src/handler/server/common.rs b/crates/rmcp/src/handler/server/common.rs index 69bbc803..aa996b5b 100644 --- a/crates/rmcp/src/handler/server/common.rs +++ b/crates/rmcp/src/handler/server/common.rs @@ -50,36 +50,48 @@ pub fn schema_for_type() -> Arc { }) } -/// Clone the schema and remove top-level "title" and "description" fields -/// (the wrapper type name and doc, which are noise to the LLM). -fn strip_top_level_metadata(schema: &Arc) -> Arc { - let mut object = schema.as_ref().clone(); - object.remove("title"); - object.remove("description"); - Arc::new(object) +/// Validate that the schema root is `type: "object"` (per MCP spec) and strip top-level +/// `title`/`description` (the wrapper type name and doc, which are noise to the LLM). +fn validate_and_strip(raw: &Arc, purpose: &str) -> Result, String> { + match raw.get("type") { + Some(serde_json::Value::String(t)) if t == "object" => { + let mut object = raw.as_ref().clone(); + object.remove("title"); + object.remove("description"); + Ok(Arc::new(object)) + } + Some(serde_json::Value::String(t)) => Err(format!( + "MCP specification requires tool {purpose} to have root type 'object', but found '{t}'." + )), + None => Err(format!( + "Schema is missing 'type' field. MCP specification requires {purpose} to have root type 'object'." + )), + Some(other) => Err(format!( + "Schema 'type' field has unexpected format: {other:?}. Expected \"object\"." + )), + } } -/// Generate a JSON schema for inputSchema (top-level "title" and "description" are removed) -pub fn schema_for_input() -> Arc { +/// Generate, validate, and strip a JSON schema for inputSchema (must have root type "object"; +/// top-level "title" and "description" are removed). +pub fn schema_for_input() -> Result, String> { thread_local! { - static CACHE_FOR_INPUT: std::sync::RwLock>> = Default::default(); + static CACHE_FOR_INPUT: std::sync::RwLock, String>>> = Default::default(); }; CACHE_FOR_INPUT.with(|cache| { - if let Some(schema) = cache + if let Some(result) = cache .read() .expect("input schema cache lock poisoned") .get(&TypeId::of::()) { - schema.clone() - } else { - let schema = strip_top_level_metadata(&schema_for_type::()); - cache - .write() - .expect("input schema cache lock poisoned") - .insert(TypeId::of::(), schema.clone()); - - schema + return result.clone(); } + let result = validate_and_strip(&schema_for_type::(), "inputSchema"); + cache + .write() + .expect("input schema cache lock poisoned") + .insert(TypeId::of::(), result.clone()); + result }) } @@ -111,21 +123,7 @@ pub fn schema_for_output() -> Result(); - let result = match raw.get("type") { - Some(serde_json::Value::String(t)) if t == "object" => { - Ok(strip_top_level_metadata(&raw)) - } - Some(serde_json::Value::String(t)) => Err(format!( - "MCP specification requires tool outputSchema to have root type 'object', but found '{t}'." - )), - None => Err( - "Schema is missing 'type' field. MCP specification requires outputSchema to have root type 'object'.".to_string() - ), - Some(other) => Err(format!( - "Schema 'type' field has unexpected format: {other:?}. Expected \"object\"." - )), - }; + let result = validate_and_strip(&schema_for_type::(), "outputSchema"); // Cache the result (both success and error cases) cache @@ -330,4 +328,28 @@ mod tests { let schema = schema_for_output::().unwrap(); assert!(!schema.contains_key("description")); } + + #[test] + fn test_schema_for_input_rejects_primitive() { + let result = schema_for_input::(); + assert!(result.is_err()); + } + + #[test] + fn test_schema_for_input_accepts_object() { + let result = schema_for_input::(); + assert!(result.is_ok()); + } + + #[test] + fn test_schema_for_input_strips_top_level_title() { + let schema = schema_for_input::().unwrap(); + assert!(!schema.contains_key("title")); + } + + #[test] + fn test_schema_for_input_strips_top_level_description() { + let schema = schema_for_input::().unwrap(); + assert!(!schema.contains_key("description")); + } } diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs index 8e291327..c2bf299c 100644 --- a/crates/rmcp/src/handler/server/router/tool.rs +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -250,7 +250,9 @@ where attr: Tool::new( name.into(), "", - schema_for_input::(), + schema_for_input::().unwrap_or_else(|e| { + panic!("Invalid input schema for JsonObject: {e}"); + }), ), call: self, _marker: std::marker::PhantomData, @@ -287,7 +289,12 @@ where self } pub fn parameters(mut self) -> Self { - self.attr.input_schema = schema_for_input::(); + self.attr.input_schema = schema_for_input::().unwrap_or_else(|e| { + panic!( + "Invalid input schema for `{}`: {e}", + std::any::type_name::() + ) + }); self } pub fn parameters_value(mut self, schema: serde_json::Value) -> Self { diff --git a/crates/rmcp/src/handler/server/router/tool/tool_traits.rs b/crates/rmcp/src/handler/server/router/tool/tool_traits.rs index 57977db3..b0bf9e2d 100644 --- a/crates/rmcp/src/handler/server/router/tool/tool_traits.rs +++ b/crates/rmcp/src/handler/server/router/tool/tool_traits.rs @@ -49,7 +49,14 @@ pub trait ToolBase { /// If the tool does not have any parameters, you should override this methods to return [`None`], /// and when invoked, the parameter will get default values. fn input_schema() -> Option> { - Some(schema_for_input::>()) + Some( + schema_for_input::>().unwrap_or_else(|e| { + panic!( + "Invalid input schema for ToolBase::Parameter type `{0}`: {e}", + std::any::type_name::(), + ); + }), + ) } /// Json schema for tool output. diff --git a/crates/rmcp/src/model/tool.rs b/crates/rmcp/src/model/tool.rs index 11bba529..2ed89d6a 100644 --- a/crates/rmcp/src/model/tool.rs +++ b/crates/rmcp/src/model/tool.rs @@ -330,7 +330,8 @@ impl Tool { /// Set the input schema using a type that implements JsonSchema #[cfg(feature = "server")] pub fn with_input_schema(mut self) -> Self { - self.input_schema = crate::handler::server::tool::schema_for_input::(); + self.input_schema = crate::handler::server::tool::schema_for_input::() + .unwrap_or_else(|e| panic!("Invalid input schema for tool '{}': {}", self.name, e)); self } diff --git a/crates/rmcp/tests/test_structured_output.rs b/crates/rmcp/tests/test_structured_output.rs index 7bb62e65..adbdfec5 100644 --- a/crates/rmcp/tests/test_structured_output.rs +++ b/crates/rmcp/tests/test_structured_output.rs @@ -28,6 +28,16 @@ pub struct UserInfo { pub age: u32, } +#[derive(Serialize, Deserialize, JsonSchema)] +pub struct GreetingRequest { + pub name: String, +} + +#[derive(Serialize, Deserialize, JsonSchema)] +pub struct GetUserRequest { + pub user_id: String, +} + #[tool_handler(router = self.tool_router)] impl ServerHandler for TestServer {} @@ -64,14 +74,17 @@ impl TestServer { /// Tool that returns regular string output #[tool(name = "get-greeting", description = "Get a greeting")] - pub async fn get_greeting(&self, name: Parameters) -> String { - format!("Hello, {}!", name.0) + pub async fn get_greeting(&self, params: Parameters) -> String { + format!("Hello, {}!", params.0.name) } /// Tool that returns structured user info #[tool(name = "get-user", description = "Get user info")] - pub async fn get_user(&self, user_id: Parameters) -> Result, String> { - if user_id.0 == "123" { + pub async fn get_user( + &self, + params: Parameters, + ) -> Result, String> { + if params.0.user_id == "123" { Ok(Json(UserInfo { name: "Alice".to_string(), age: 30,