Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ async fn main() -> anyhow::Result<()> {
}
```

The generated tool `inputSchema` is derived from the fields of `T`. The type name and documentation on `T` are ignored; only field names, field types, and field documentation are used.
The generated tool `inputSchema` and `outputSchema` are derived from the fields of `T`. The type name and documentation on `T` are ignored; only field names, field types, and field documentation are used.

When you need custom server metadata or multiple capabilities (tools + prompts), use explicit `#[tool_handler]`:

Expand Down
7 changes: 7 additions & 0 deletions crates/rmcp-macros/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,13 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
// if found, use the Parameters schema
syn::parse2::<Expr>(quote! {
rmcp::handler::server::common::schema_for_input::<#params_ty>()
.unwrap_or_else(|e| {
panic!(
"Invalid input schema for `{}`: {}",
std::any::type_name::<#params_ty>(),
e
)
})
})?
} else {
// if not found, use a default empty JSON schema object
Expand Down
142 changes: 89 additions & 53 deletions crates/rmcp/src/handler/server/common.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
//! Common utilities shared between tool and prompt handlers

use std::{any::TypeId, collections::HashMap, sync::Arc};
use std::{
any::TypeId,
collections::HashMap,
sync::{Arc, LazyLock},
};

use schemars::JsonSchema;

Expand Down Expand Up @@ -30,12 +34,10 @@ pub fn schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
let generator = settings.into_generator();
let schema = generator.into_root_schema_for::<T>();
let object = serde_json::to_value(schema).expect("failed to serialize schema");
let object = match object {
serde_json::Value::Object(object) => object,
_ => panic!(
"Schema serialization produced non-object value: expected JSON object but got {:?}",
object
),
let serde_json::Value::Object(object) = object else {
panic!(
"Schema serialization produced non-object value: expected JSON object but got {object:?}"
);
};
let schema = Arc::new(object);
cache
Expand All @@ -48,51 +50,63 @@ pub fn schema_for_type<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
})
}

/// Generate a JSON schema for inputSchema (does not need "title" or "description" fields for the top-level object)
pub fn schema_for_input<T: JsonSchema + std::any::Any>() -> Arc<JsonObject> {
/// Validate that the schema root is `type: "object"` (per MCP spec) and strip top-level
/// `title`/`description` (the wrapper type name and doc, which are noise to the LLM).
fn validate_and_strip(raw: &Arc<JsonObject>, purpose: &str) -> Result<Arc<JsonObject>, String> {
match raw.get("type") {
Some(serde_json::Value::String(t)) if t == "object" => {
let mut object = raw.as_ref().clone();
object.remove("title");
object.remove("description");
Ok(Arc::new(object))
}
Some(serde_json::Value::String(t)) => Err(format!(
"MCP specification requires tool {purpose} to have root type 'object', but found '{t}'."
)),
None => Err(format!(
"Schema is missing 'type' field. MCP specification requires {purpose} to have root type 'object'."
)),
Some(other) => Err(format!(
"Schema 'type' field has unexpected format: {other:?}. Expected \"object\"."
)),
}
}

/// Generate, validate, and strip a JSON schema for inputSchema (must have root type "object";
/// top-level "title" and "description" are removed).
pub fn schema_for_input<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObject>, String> {
thread_local! {
static CACHE_FOR_INPUT: std::sync::RwLock<HashMap<TypeId, Arc<JsonObject>>> = Default::default();
static CACHE_FOR_INPUT: std::sync::RwLock<HashMap<TypeId, Result<Arc<JsonObject>, String>>> = Default::default();
};
CACHE_FOR_INPUT.with(|cache| {
if let Some(schema) = cache
if let Some(result) = cache
.read()
.expect("input schema cache lock poisoned")
.get(&TypeId::of::<T>())
{
schema.clone()
} else {
let mut schema = schema_for_type::<T>().as_ref().clone();

// Remove unnecessary top-level fields
schema.remove("title");
schema.remove("description");

let schema = Arc::new(schema);
cache
.write()
.expect("input schema cache lock poisoned")
.insert(TypeId::of::<T>(), schema.clone());

schema
return result.clone();
}
let result = validate_and_strip(&schema_for_type::<T>(), "inputSchema");
cache
.write()
.expect("input schema cache lock poisoned")
.insert(TypeId::of::<T>(), result.clone());
result
})
}

// TODO: should be updated according to the new specifications
/// Schema used when input is empty.
pub fn schema_for_empty_input() -> Arc<JsonObject> {
std::sync::Arc::new(
serde_json::json!({
"type": "object",
"properties": {}
})
.as_object()
.unwrap()
.clone(),
)
static EMPTY: LazyLock<Arc<JsonObject>> = LazyLock::new(|| {
let mut object = JsonObject::new();
object.insert("type".into(), serde_json::json!("object"));
object.insert("properties".into(), serde_json::json!({}));
Arc::new(object)
});
EMPTY.clone()
}

/// Generate and validate a JSON schema for outputSchema (must have root type "object").
/// Generate a JSON schema for outputSchema (must have root type "object"; top-level "title" and "description" are removed)
pub fn schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObject>, String> {
thread_local! {
static CACHE_FOR_OUTPUT: std::sync::RwLock<HashMap<TypeId, Result<Arc<JsonObject>, String>>> = Default::default();
Expand All @@ -108,22 +122,8 @@ pub fn schema_for_output<T: JsonSchema + std::any::Any>() -> Result<Arc<JsonObje
return result.clone();
}

// Generate and validate schema
let schema = schema_for_type::<T>();
let result = match schema.get("type") {
Some(serde_json::Value::String(t)) if t == "object" => Ok(schema.clone()),
Some(serde_json::Value::String(t)) => Err(format!(
"MCP specification requires tool outputSchema to have root type 'object', but found '{}'.",
t
)),
None => Err(
"Schema is missing 'type' field. MCP specification requires outputSchema to have root type 'object'.".to_string()
),
Some(other) => Err(format!(
"Schema 'type' field has unexpected format: {:?}. Expected \"object\".",
other
)),
};
// Generate, validate, and strip unnecessary top-level fields
let result = validate_and_strip(&schema_for_type::<T>(), "outputSchema");

// Cache the result (both success and error cases)
cache
Expand Down Expand Up @@ -316,4 +316,40 @@ mod tests {
let result = schema_for_output::<TestObject>();
assert!(result.is_ok(),);
}

#[test]
fn test_schema_for_output_strips_top_level_title() {
let schema = schema_for_output::<TestObject>().unwrap();
assert!(!schema.contains_key("title"));
}

#[test]
fn test_schema_for_output_strips_top_level_description() {
let schema = schema_for_output::<TestObject>().unwrap();
assert!(!schema.contains_key("description"));
}

#[test]
fn test_schema_for_input_rejects_primitive() {
let result = schema_for_input::<i32>();
assert!(result.is_err());
}

#[test]
fn test_schema_for_input_accepts_object() {
let result = schema_for_input::<TestObject>();
assert!(result.is_ok());
}

#[test]
fn test_schema_for_input_strips_top_level_title() {
let schema = schema_for_input::<TestObject>().unwrap();
assert!(!schema.contains_key("title"));
}

#[test]
fn test_schema_for_input_strips_top_level_description() {
let schema = schema_for_input::<TestObject>().unwrap();
assert!(!schema.contains_key("description"));
}
}
11 changes: 9 additions & 2 deletions crates/rmcp/src/handler/server/router/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,9 @@ where
attr: Tool::new(
name.into(),
"",
schema_for_input::<crate::model::JsonObject>(),
schema_for_input::<crate::model::JsonObject>().unwrap_or_else(|e| {
panic!("Invalid input schema for JsonObject: {e}");
}),
),
call: self,
_marker: std::marker::PhantomData,
Expand Down Expand Up @@ -287,7 +289,12 @@ where
self
}
pub fn parameters<T: JsonSchema + 'static>(mut self) -> Self {
self.attr.input_schema = schema_for_input::<T>();
self.attr.input_schema = schema_for_input::<T>().unwrap_or_else(|e| {
panic!(
"Invalid input schema for `{}`: {e}",
std::any::type_name::<T>()
)
});
self
}
pub fn parameters_value(mut self, schema: serde_json::Value) -> Self {
Expand Down
9 changes: 8 additions & 1 deletion crates/rmcp/src/handler/server/router/tool/tool_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,14 @@ pub trait ToolBase {
/// If the tool does not have any parameters, you should override this methods to return [`None`],
/// and when invoked, the parameter will get default values.
fn input_schema() -> Option<Arc<JsonObject>> {
Some(schema_for_input::<Parameters<Self::Parameter>>())
Some(
schema_for_input::<Parameters<Self::Parameter>>().unwrap_or_else(|e| {
panic!(
"Invalid input schema for ToolBase::Parameter type `{0}`: {e}",
std::any::type_name::<Self::Parameter>(),
);
}),
)
}

/// Json schema for tool output.
Expand Down
3 changes: 2 additions & 1 deletion crates/rmcp/src/model/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,8 @@ impl Tool {
/// Set the input schema using a type that implements JsonSchema
#[cfg(feature = "server")]
pub fn with_input_schema<T: JsonSchema + 'static>(mut self) -> Self {
self.input_schema = crate::handler::server::tool::schema_for_input::<T>();
self.input_schema = crate::handler::server::tool::schema_for_input::<T>()
.unwrap_or_else(|e| panic!("Invalid input schema for tool '{}': {}", self.name, e));
self
}

Expand Down
14 changes: 11 additions & 3 deletions crates/rmcp/tests/test_list_tools_result.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#![cfg(all(feature = "server", feature = "macros", not(feature = "local")))]

use rmcp::{
Json,
handler::server::wrapper::Parameters,
model::{ListToolsResult, NumberOrString, ServerJsonRpcMessage, ServerResult},
};
Expand All @@ -14,10 +15,17 @@ struct AddRequest {
b: f64,
}

/// Result of adding two numbers.
#[derive(Debug, serde::Serialize, schemars::JsonSchema)]
struct AddResult {
/// The sum of the two numbers.
sum: f64,
}

/// Add two numbers.
#[rmcp::tool]
fn add(Parameters(AddRequest { a, b }): Parameters<AddRequest>) -> String {
(a + b).to_string()
fn add(Parameters(AddRequest { a, b }): Parameters<AddRequest>) -> Json<AddResult> {
Json(AddResult { sum: a + b })
}

#[test]
Expand All @@ -27,7 +35,7 @@ fn list_tools_result_matches_expected_json() {
let expected: serde_json::Value =
serde_json::from_slice(&expected_json).expect("invalid expected JSON fixture");

assert_eq!(add(Parameters(AddRequest { a: 1.0, b: 2.0 })), "3");
assert_eq!(add(Parameters(AddRequest { a: 1.0, b: 2.0 })).0.sum, 3.0);

let result = ListToolsResult::with_all_items(vec![add_tool_attr()]);
let response = ServerJsonRpcMessage::response(
Expand Down
14 changes: 14 additions & 0 deletions crates/rmcp/tests/test_list_tools_result/list_tools_result.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@
"a",
"b"
]
},
"outputSchema": {
"$schema": "https://json-schema.org/draft/2020-12/schema",
"type": "object",
"properties": {
"sum": {
"description": "The sum of the two numbers.",
"format": "double",
"type": "number"
}
},
"required": [
"sum"
]
}
}
]
Expand Down
21 changes: 17 additions & 4 deletions crates/rmcp/tests/test_structured_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,16 @@ pub struct UserInfo {
pub age: u32,
}

#[derive(Serialize, Deserialize, JsonSchema)]
pub struct GreetingRequest {
pub name: String,
}

#[derive(Serialize, Deserialize, JsonSchema)]
pub struct GetUserRequest {
pub user_id: String,
}

#[tool_handler(router = self.tool_router)]
impl ServerHandler for TestServer {}

Expand Down Expand Up @@ -64,14 +74,17 @@ impl TestServer {

/// Tool that returns regular string output
#[tool(name = "get-greeting", description = "Get a greeting")]
pub async fn get_greeting(&self, name: Parameters<String>) -> String {
format!("Hello, {}!", name.0)
pub async fn get_greeting(&self, params: Parameters<GreetingRequest>) -> String {
format!("Hello, {}!", params.0.name)
}

/// Tool that returns structured user info
#[tool(name = "get-user", description = "Get user info")]
pub async fn get_user(&self, user_id: Parameters<String>) -> Result<Json<UserInfo>, String> {
if user_id.0 == "123" {
pub async fn get_user(
&self,
params: Parameters<GetUserRequest>,
) -> Result<Json<UserInfo>, String> {
if params.0.user_id == "123" {
Ok(Json(UserInfo {
name: "Alice".to_string(),
age: 30,
Expand Down