diff --git a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs index 3f5756620..54a1a9669 100644 --- a/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs +++ b/tests/ModelContextProtocol.Tests/Transport/StdioServerTransportTests.cs @@ -22,14 +22,6 @@ public StdioServerTransportTests(ITestOutputHelper testOutputHelper) InitializationTimeout = TimeSpan.FromSeconds(10), ServerInstructions = "Test Instructions" }; - - // Override the LoggerFactory to use Trace level for testing Trace-level logging - LoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => - { - builder.AddProvider(XunitLoggerProvider); - builder.AddProvider(MockLoggerProvider); - builder.SetMinimumLevel(LogLevel.Trace); - }); } [Fact(Skip="https://github.com/modelcontextprotocol/csharp-sdk/issues/143")] @@ -207,19 +199,27 @@ public async Task SendMessageAsync_Should_Preserve_Unicode_Characters() public async Task SendMessageAsync_Should_Log_At_Trace_Level() { // Arrange + var mockLoggerProvider = new MockLoggerProvider(); + using var traceLoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => + { + builder.AddProvider(XunitLoggerProvider); + builder.AddProvider(mockLoggerProvider); + builder.SetMinimumLevel(LogLevel.Trace); + }); + using var output = new MemoryStream(); await using var transport = new StreamServerTransport( new Pipe().Reader.AsStream(), output, - loggerFactory: LoggerFactory); + loggerFactory: traceLoggerFactory); // Act var message = new JsonRpcRequest { Method = "test", Id = new RequestId(44) }; await transport.SendMessageAsync(message, TestContext.Current.CancellationToken); // Assert - var traceLogMessages = MockLoggerProvider.LogMessages + var traceLogMessages = mockLoggerProvider.LogMessages .Where(x => x.LogLevel == LogLevel.Trace && x.Message.Contains("transport sending message")) .ToList(); @@ -231,6 +231,14 @@ public async Task SendMessageAsync_Should_Log_At_Trace_Level() public async Task ReadMessagesAsync_Should_Log_Received_At_Trace_Level() { // Arrange + var mockLoggerProvider = new MockLoggerProvider(); + using var traceLoggerFactory = Microsoft.Extensions.Logging.LoggerFactory.Create(builder => + { + builder.AddProvider(XunitLoggerProvider); + builder.AddProvider(mockLoggerProvider); + builder.SetMinimumLevel(LogLevel.Trace); + }); + var message = new JsonRpcRequest { Method = "test", Id = new RequestId(99) }; var json = JsonSerializer.Serialize(message, McpJsonUtilities.DefaultOptions); @@ -240,7 +248,7 @@ public async Task ReadMessagesAsync_Should_Log_Received_At_Trace_Level() await using var transport = new StreamServerTransport( input, Stream.Null, - loggerFactory: LoggerFactory); + loggerFactory: traceLoggerFactory); // Act await pipe.Writer.WriteAsync(Encoding.UTF8.GetBytes($"{json}\n"), TestContext.Current.CancellationToken); @@ -250,7 +258,7 @@ public async Task ReadMessagesAsync_Should_Log_Received_At_Trace_Level() Assert.True(canRead, "Nothing to read here from transport message reader"); // Assert - var traceLogMessages = MockLoggerProvider.LogMessages + var traceLogMessages = mockLoggerProvider.LogMessages .Where(x => x.LogLevel == LogLevel.Trace && x.Message.Contains("transport received message")) .ToList();