diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Credentials.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Credentials.cpp index f2f4a0d4236..819dd552429 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Credentials.cpp +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Credentials.cpp @@ -1,50 +1,51 @@ -#include "Connection/Credentials.h" -#include "Misc/Paths.h" -#include "Misc/ConfigCacheIni.h" - -FString UCredentials::Token; -FString UCredentials::StoredKey; - -void UCredentials::Init(const FString& InFilename) -{ - StoredKey = InFilename; - LoadToken(); -} - -FString UCredentials::LoadToken() -{ - FString LoadedValue; - if (StoredKey.IsEmpty()) - { - UE_LOG(LogTemp, Warning, TEXT("UCredentials::Init has not been called before LoadToken.")); - return Token; - } - - if (GConfig->GetString(TEXT("SpacetimeDB"), *StoredKey, LoadedValue, GGameUserSettingsIni)) - { - Token = LoadedValue; - UE_LOG(LogTemp, Verbose, TEXT("UCredentials::Credentials loaded for key %s from %s."), *StoredKey, *FPaths::GetCleanFilename(GGameUserSettingsIni)); - } - else - { - UE_LOG(LogTemp, Verbose, TEXT("UCredentials::No stored credentials found for key %s."), *StoredKey); - } - - return Token; -} - -void UCredentials::SaveToken(const FString& InToken) -{ - Token = InToken; - - if (StoredKey.IsEmpty()) - { - UE_LOG(LogTemp, Warning, TEXT("UCredentials::Init has not been called before SaveToken.")); - return; - } - - GConfig->SetString(TEXT("SpacetimeDB"), *StoredKey, *Token, GGameUserSettingsIni); - - // This call writes the in-memory changes to the GGameUserSettingsIni file on the disk. - GConfig->Flush(false, GGameUserSettingsIni); -} +#include "Connection/Credentials.h" +#include "Connection/LogCategory.h" +#include "Misc/Paths.h" +#include "Misc/ConfigCacheIni.h" + +FString UCredentials::Token; +FString UCredentials::StoredKey; + +void UCredentials::Init(const FString& InFilename) +{ + StoredKey = InFilename; + LoadToken(); +} + +FString UCredentials::LoadToken() +{ + FString LoadedValue; + if (StoredKey.IsEmpty()) + { + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("UCredentials::Init has not been called before LoadToken.")); + return Token; + } + + if (GConfig->GetString(TEXT("SpacetimeDB"), *StoredKey, LoadedValue, GGameUserSettingsIni)) + { + Token = LoadedValue; + UE_LOG(LogSpacetimeDb_Connection, Verbose, TEXT("UCredentials::Credentials loaded for key %s from %s."), *StoredKey, *FPaths::GetCleanFilename(GGameUserSettingsIni)); + } + else + { + UE_LOG(LogSpacetimeDb_Connection, Verbose, TEXT("UCredentials::No stored credentials found for key %s."), *StoredKey); + } + + return Token; +} + +void UCredentials::SaveToken(const FString& InToken) +{ + Token = InToken; + + if (StoredKey.IsEmpty()) + { + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("UCredentials::Init has not been called before SaveToken.")); + return; + } + + GConfig->SetString(TEXT("SpacetimeDB"), *StoredKey, *Token, GGameUserSettingsIni); + + // This call writes the in-memory changes to the GGameUserSettingsIni file on the disk. + GConfig->Flush(false, GGameUserSettingsIni); +} diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp index 968d9e41c84..f66fd5d1d29 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBase.cpp @@ -1,705 +1,706 @@ -#include "Connection/DbConnectionBase.h" -#include "Connection/DbConnectionBuilder.h" -#include "Connection/Credentials.h" -#include "ModuleBindings/Types/ClientMessageType.g.h" -#include "ModuleBindings/Types/SubscribeMultiType.g.h" -#include "ModuleBindings/Types/UnsubscribeMultiType.g.h" -#include "ModuleBindings/Types/SubscribeMultiAppliedType.g.h" -#include "ModuleBindings/Types/UnsubscribeMultiAppliedType.g.h" -#include "ModuleBindings/Types/SubscriptionErrorType.g.h" -#include "ModuleBindings/Types/DatabaseUpdateType.g.h" -#include "ModuleBindings/Types/CompressableQueryUpdateType.g.h" -#include "Misc/Compression.h" -#include "Misc/ScopeLock.h" -#include "Async/Async.h" -#include "BSATN/UEBSATNHelpers.h" -#include "Connection/ProcedureFlags.h" - -UDbConnectionBase::UDbConnectionBase(const FObjectInitializer& ObjectInitializer) - : Super(ObjectInitializer) -{ - NextRequestId = 1; - NextSubscriptionId = 1; - ProcedureCallbacks = CreateDefaultSubobject(TEXT("ProcedureCallbacks")); -} - -void UDbConnectionBase::Disconnect() -{ - if (WebSocket) - { - WebSocket->Disconnect(); - } -} - -bool UDbConnectionBase::IsActive() const -{ - return WebSocket && WebSocket->IsConnected(); -} - - -bool UDbConnectionBase::TryGetIdentity(FSpacetimeDBIdentity& OutIdentity) const -{ - if (bIsIdentitySet) - { - OutIdentity = Identity; - return true; - } - - UE_LOG(LogTemp, Warning, TEXT("TryGetIdentity called before identity was set")); - return false; -} - -FSpacetimeDBConnectionId UDbConnectionBase::GetConnectionId() const -{ - return ConnectionId; -} - -bool UDbConnectionBase::SendRawMessage(const FString& Message) -{ - return WebSocket && WebSocket->SendMessage(Message); -} - -bool UDbConnectionBase::SendRawMessage(const TArray& Message) -{ - return WebSocket && WebSocket->SendMessage(Message); -} - -USubscriptionBuilderBase* UDbConnectionBase::SubscriptionBuilderBase() -{ - return NewObject(); -} - -void UDbConnectionBase::HandleWSError(const FString& Error) -{ - if (OnConnectErrorDelegate.IsBound()) - { - OnConnectErrorDelegate.Execute(Error); - } -} - -void UDbConnectionBase::HandleWSClosed(int32 /*StatusCode*/, const FString& Reason, bool /*bWasClean*/) -{ - if (OnDisconnectBaseDelegate.IsBound()) - { - OnDisconnectBaseDelegate.Execute(this, Reason); - } -} - -void UDbConnectionBase::HandleWSBinaryMessage(const TArray& Message) -{ - //tag for arrival order - const int32 Id = NextPreprocessId.GetValue(); - NextPreprocessId.Increment(); - - //do expensive work off-thread - TWeakObjectPtr WeakThis(this); - Async(EAsyncExecution::Thread, [WeakThis, Message, Id]() - { - if (!WeakThis.IsValid()) - { - return; - } - UDbConnectionBase* This = WeakThis.Get(); - - //parse the message, decompress if needed - FServerMessageType Parsed = This->PreProcessMessage(Message); - - //queue: re-order buffer - TArray Ready; - { - FScopeLock Lock(&This->PreprocessMutex); - // Move the parsed message into the map to avoid copying - This->PreprocessedMessages.Add(Id, MoveTemp(Parsed)); - //check if we can release any messages in order - while (This->PreprocessedMessages.Contains(This->NextReleaseId)) - { - Ready.Add(This->PreprocessedMessages.FindAndRemoveChecked(This->NextReleaseId)); - ++This->NextReleaseId; - } - } - //if we have any ready messages, append them to the pending messages list that is processed in Tick - if (Ready.Num() > 0) - { - FScopeLock Lock(&This->PendingMessagesMutex); - This->PendingMessages.Append(MoveTemp(Ready)); - } - }); -} - -void UDbConnectionBase::FrameTick() -{ - TArray Local; - { - FScopeLock Lock(&PendingMessagesMutex); - if (PendingMessages.Num() == 0) - { - //nothing to process, return early - return; - } - //move pending messages to local array for processing - Local = MoveTemp(PendingMessages); - PendingMessages.Empty(); - } - - //process all messages in the local array - for (const FServerMessageType& Msg : Local) - { - //process the message, this will call DbUpdate or trigger subscription events as needed - ProcessServerMessage(Msg); - } -} -void UDbConnectionBase::Tick(float DeltaTime) -{ - if (bIsAutoTicking) - { - FrameTick(); - } -} - -TStatId UDbConnectionBase::GetStatId() const -{ - // This is used by the engine to track tickables, we return a unique stat ID for this class - RETURN_QUICK_DECLARE_CYCLE_STAT(UMyTickableObject, STATGROUP_Tickables); -} - -bool UDbConnectionBase::IsTickable() const -{ - return bIsAutoTicking; -} - -bool UDbConnectionBase::IsTickableInEditor() const -{ - return bIsAutoTicking; -} - - -void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) -{ - bool bIsValid = false; - switch (Message.Tag) - { - case EServerMessageTag::InitialSubscription: - { - //@Note: This is a legacy tag, used implemented in current server version - break; - } - case EServerMessageTag::TransactionUpdate: - { - // Process a transaction update message - const FTransactionUpdateType Payload = Message.GetAsTransactionUpdate(); - - // Create a status object based on the transaction status - FSpacetimeDBStatus StatusObj; - bool bSuccess = false; - FString ErrorMessage; - if (Payload.Status.IsCommitted()) - { - bSuccess = true; - StatusObj = FSpacetimeDBStatus::Committed(FSpacetimeDBUnit()); - } - else if (Payload.Status.IsFailed()) - { - ErrorMessage = Payload.Status.GetAsFailed(); - StatusObj = FSpacetimeDBStatus::Failed(ErrorMessage); - } - else if (Payload.Status.IsOutOfEnergy()) - { - Payload.Status.GetAsOutOfEnergy(); - StatusObj = FSpacetimeDBStatus::OutOfEnergy(FSpacetimeDBUnit()); - ErrorMessage = TEXT("Out of energy"); - } - - // Process the transaction update and create a reducer event - FReducerEvent RedEvent; - RedEvent.Timestamp = Payload.Timestamp; - RedEvent.Status = StatusObj; - RedEvent.CallerIdentity = Payload.CallerIdentity; - RedEvent.CallerConnectionId = Payload.CallerConnectionId; - RedEvent.EnergyConsumed = Payload.EnergyQuantaUsed; - RedEvent.ReducerCall = Payload.ReducerCall; - - // If the status is committed, we update the database - if (bSuccess) - { - DbUpdate(Payload.Status.GetAsCommitted(), FSpacetimeDBEvent::Reducer(RedEvent)); // Update table and trigger insert/update/delete - ReducerEvent(RedEvent); // Trigger the reducer event - } - else - { - ReducerEvent(RedEvent); // Trigger the reducer event - ReducerEventFailed(RedEvent, ErrorMessage); - } - break; - } - case EServerMessageTag::TransactionUpdateLight: - { - // Process a light transaction update message - const FTransactionUpdateLightType Payload = Message.GetAsTransactionUpdateLight(); - - //@TODO: Implement light update fully - DbUpdate(Payload.Update, FSpacetimeDBEvent::UnknownTransaction(FSpacetimeDBUnit())); - - break; - } - case EServerMessageTag::IdentityToken: - { - // Process an identity token message - const FIdentityTokenType Payload = Message.GetAsIdentityToken(); - - Token = Payload.Token; - UCredentials::SaveToken(Token); - Identity = Payload.Identity; - bIsIdentitySet = true; - UE_LOG(LogTemp, Verbose, TEXT("IdentityToken: Identity set to: %s"), *Identity.ToHex()); - ConnectionId = Payload.ConnectionId; - if (OnConnectBaseDelegate.IsBound()) - { - OnConnectBaseDelegate.Execute(this, Identity, Token); - } - break; - } - case EServerMessageTag::OneOffQueryResponse: - { - //@Note: Not implemented in Rust version, skip for now here aswell - break; - } - case EServerMessageTag::SubscribeApplied: - { - //@Note: This is a legacy tag, not implemented in current server version - break; - } - case EServerMessageTag::UnsubscribeApplied: - { - //@Note: This is a legacy tag, not implemented in current server version - break; - } - case EServerMessageTag::SubscriptionError: - { - // Process a subscription error message - const FSubscriptionErrorType Payload = Message.GetAsSubscriptionError(); - if (TObjectPtr Handle = *ActiveSubscriptions.Find(Payload.QueryId.Value)) - { - if (!Handle) - { - UE_LOG(LogTemp, Error, TEXT("SubscriptionError: Null handle for QueryId %u. Error: %s"), - Payload.QueryId.Value, - *Payload.Error); - return; - } - FErrorContextBase Ctx; Ctx.Error = Payload.Error; - Handle->TriggerError(Ctx); - ActiveSubscriptions.Remove(Payload.QueryId.Value); - } - break; - } - case EServerMessageTag::SubscribeMultiApplied: - { - // Process a multi-subscription applied message - const FSubscribeMultiAppliedType Payload = Message.GetAsSubscribeMultiApplied(); - // Update the database with the subscription applied event - DbUpdate(Payload.Update, FSpacetimeDBEvent::SubscribeApplied(FSpacetimeDBUnit())); - - if (TObjectPtr Handle = *ActiveSubscriptions.Find(Payload.QueryId.Id)) - { - if (!Handle) - { - UE_LOG(LogTemp, Error, TEXT("SubscriptionError: Null handle for QueryId %u."), Payload.QueryId.Id); - return; - } - FSubscriptionEventContextBase Ctx; Ctx.Event = FSpacetimeDBEvent::SubscribeApplied(FSpacetimeDBUnit()); - Handle->TriggerApplied(Ctx); - } - - break; - } - case EServerMessageTag::UnsubscribeMultiApplied: - { - // Process a multi-unsubscription applied message - const FUnsubscribeMultiAppliedType Payload = Message.GetAsUnsubscribeMultiApplied(); - - // Update the database with the unsubscription applied event - DbUpdate(Payload.Update, FSpacetimeDBEvent::UnsubscribeApplied(FSpacetimeDBUnit())); - if (TObjectPtr Handle = *ActiveSubscriptions.Find(Payload.QueryId.Id)) - { - if (!Handle) - { - UE_LOG(LogTemp, Error, TEXT("UnsubscribeMultiApplied: Null handle for QueryId %u."), Payload.QueryId.Id); - return; - } - Handle->bEnded = true; - Handle->bActive = false; - Handle->bUnsubscribeCalled = true; - FSubscriptionEventContextBase Ctx; Ctx.Event = FSpacetimeDBEvent::UnsubscribeApplied(FSpacetimeDBUnit()); - if (Handle->EndDelegate.IsBound()) - { - Handle->EndDelegate.Execute(Ctx); - } - ActiveSubscriptions.Remove(Payload.QueryId.Id); - } - break; - } - case EServerMessageTag::ProcedureResult: - { - const FProcedureResultType Payload = Message.GetAsProcedureResult(); - FProcedureEvent ProcEvent; - ProcEvent.Status = Payload.Status; - ProcEvent.Timestamp = Payload.Timestamp; - ProcEvent.TotalHostExecutionDuration = Payload.TotalHostExecutionDuration; - ProcEvent.Success = ProcEvent.Status.IsReturned(); - TArray PayloadData; - FString ErrorMessage = ""; - if (ProcEvent.Success) - PayloadData = ProcEvent.Status.GetAsReturned(); - if (Payload.Status.IsOutOfEnergy()) - { - ErrorMessage = TEXT("Out of energy"); - } - else if (Payload.Status.IsInternalError()) - { - ErrorMessage = Payload.Status.GetAsInternalError(); - } - - ProcedureCallbacks->ResolveCallback(Payload.RequestId, FSpacetimeDBEvent::Procedure(ProcEvent), PayloadData, ProcEvent.Success); - if (!ProcEvent.Success) - { - ProcedureEventFailed(ProcEvent, ErrorMessage); - } - break; - } - default: - // Unknown tag - bail out - UE_LOG(LogTemp, Warning, TEXT("Unknown server-message tag")); - break; - } -} - -bool UDbConnectionBase::DecompressBrotli(const TArray& InData, TArray& OutData) -{ - UE_LOG(LogTemp, Error, TEXT("Brotli decompression unavilable")); - return false; -} - -bool UDbConnectionBase::DecompressGzip(const TArray& InData, TArray& OutData) -{ - if (InData.Num() < 4) - { - UE_LOG(LogTemp, Error, TEXT("Gzip data too small")); - return false; - } - - // Gzip data ends with 4 bytes indicating the uncompressed size - const uint8* SizePtr = InData.GetData() + InData.Num() - 4; - uint32 OutSize = SizePtr[0] | (SizePtr[1] << 8) | (SizePtr[2] << 16) | (SizePtr[3] << 24); - - // Validate the output size - OutData.SetNumUninitialized(OutSize); - // Attempt to decompress the Gzip data - if (!FCompression::UncompressMemory(NAME_Gzip, OutData.GetData(), OutSize, InData.GetData(), InData.Num())) - { - UE_LOG(LogTemp, Error, TEXT("Gzip decompression failed")); - return false; - } - - OutData.SetNum(OutSize); - return true; -} - -bool UDbConnectionBase::DecompressPayload(ECompressableQueryUpdateTag Variant, const TArray& In, TArray& Out) -{ - switch (Variant) - { - case ECompressableQueryUpdateTag::Uncompressed: - // No compression, just copy the data - Out = In; - return true; - case ECompressableQueryUpdateTag::Brotli: - return DecompressBrotli(In, Out); - case ECompressableQueryUpdateTag::Gzip: - return DecompressGzip(In, Out); - default: - UE_LOG(LogTemp, Error, TEXT("Unknown compression variant")); - return false; - } -} - -void UDbConnectionBase::PreProcessDatabaseUpdate(const FDatabaseUpdateType& Update) -{ - for (const FTableUpdateType& TableUpdate : Update.Tables) - { - TArray UncompressedCQUs; - for (const FCompressableQueryUpdateType& CQU : TableUpdate.Updates) - { - - // Uncompress the CQU based on its tag - FQueryUpdateType UncompressedUpdate; - switch (CQU.Tag) - { - case ECompressableQueryUpdateTag::Uncompressed: - UncompressedUpdate = CQU.GetAsUncompressed(); - break; - case ECompressableQueryUpdateTag::Brotli: - { - TArray Data = CQU.GetAsBrotli(); - TArray Dec; - if (DecompressBrotli(Data, Dec)) - { - //@Note: This will never trigger until Brotli decompression is implemented - UncompressedUpdate = UE::SpacetimeDB::Deserialize(Dec); - } - break; - } - case ECompressableQueryUpdateTag::Gzip: - { - TArray Data = CQU.GetAsGzip(); - TArray Dec; - if (DecompressGzip(Data, Dec)) - { - UncompressedUpdate = UE::SpacetimeDB::Deserialize(Dec); - } - break; - } - default: - UE_LOG(LogTemp, Error, TEXT("Unknown compression variant in CQU")); - break; - } - UncompressedCQUs.Add(FCompressableQueryUpdateType::Uncompressed(UncompressedUpdate)); - UE_LOG(LogTemp, Verbose, TEXT("Table %s Inserts:%d Deletes:%d"), *TableUpdate.TableName, UncompressedUpdate.Inserts.RowsData.Num(), UncompressedUpdate.Deletes.RowsData.Num()); - } - - // After ensuring all updates are uncompressed, attempt to deserialize rows - TSharedPtr Deserializer; - { - // Find the deserializer for this table - FScopeLock Lock(&TableDeserializersMutex); - if (TSharedPtr* Found = TableDeserializers.Find(TableUpdate.TableName)) - { - // If found, use the deserializer - Deserializer = *Found; - } - else - { - UE_LOG(LogTemp, Error, TEXT("No deserializer found for table %s"), *TableUpdate.TableName); - } - } - if (Deserializer) - { - // Preprocess the table data using the deserializer - TSharedPtr Data = Deserializer->PreProcess(UncompressedCQUs, TableUpdate.TableName); - if (Data.IsValid()) - { - // Store the preprocessed data in the mutex-protected map - FScopeLock Lock(&PreprocessedDataMutex); - FPreprocessedTableKey Key(TableUpdate.TableId, TableUpdate.TableName); - TArray>& Queue = PreprocessedTableData.FindOrAdd(Key); - Queue.Add(Data); - } - } - else - { - UE_LOG(LogTemp, Error, TEXT("Skipping table %s updates due to missing deserializer"), *TableUpdate.TableName); - } - } -} - -FServerMessageType UDbConnectionBase::PreProcessMessage(const TArray& Message) -{ - if (Message.Num() == 0) - { - UE_LOG(LogTemp, Error, TEXT("Empty message recived from server, ignored")); - return FServerMessageType{}; - } - // Check if the first byte is a valid compression tag - ECompressableQueryUpdateTag Compression = static_cast(Message[0]); - TArray CompressedPayload; - CompressedPayload.Append(Message.GetData() + 1, Message.Num() - 1); - - // Decompress the payload based on the compression tag - TArray Decompressed; - if (!DecompressPayload(Compression, CompressedPayload, Decompressed)) - { - UE_LOG(LogTemp, Error, TEXT("Failed to decompress incoming message")); - return FServerMessageType{}; - } - - // Deserialize the decompressed data into a UServerMessageType object - FServerMessageType Parsed = UE::SpacetimeDB::Deserialize(Decompressed); - - // Process it based on its tag. Messages containing rows will be deserialized into rows based on registered type and table name. - bool bValid = false; - switch (Parsed.Tag) - { - case EServerMessageTag::InitialSubscription: - { - const FInitialSubscriptionType Payload = Parsed.GetAsInitialSubscription(); - // PreProcess the initial subscription payload - PreProcessDatabaseUpdate(Payload.DatabaseUpdate); - break; - } - case EServerMessageTag::TransactionUpdate: - { - - const FTransactionUpdateType Payload = Parsed.GetAsTransactionUpdate(); - if (Payload.Status.IsCommitted()) - { - // PreProcess the database update with the committed status - PreProcessDatabaseUpdate(Payload.Status.GetAsCommitted()); - } - break; - } - case EServerMessageTag::TransactionUpdateLight: - { - //@Note: Light tag in not implemented as an option in connection builder, this will never trigger but we keep this for future compatibility - const FTransactionUpdateLightType Payload = Parsed.GetAsTransactionUpdateLight(); - // PreProcess the light transaction update - PreProcessDatabaseUpdate(Payload.Update); - break; - } - case EServerMessageTag::SubscribeMultiApplied: - { - const FSubscribeMultiAppliedType Payload = Parsed.GetAsSubscribeMultiApplied(); - PreProcessDatabaseUpdate(Payload.Update); - break; - } - case EServerMessageTag::UnsubscribeMultiApplied: - { - const FUnsubscribeMultiAppliedType Payload = Parsed.GetAsUnsubscribeMultiApplied(); - PreProcessDatabaseUpdate(Payload.Update); - break; - } - default: - break; - } - return Parsed; -} - - -int32 UDbConnectionBase::GetNextRequestId() -{ - return NextRequestId++; -} - -int32 UDbConnectionBase::GetNextSubscriptionId() -{ - return NextSubscriptionId++; -} - -void UDbConnectionBase::StartSubscription(USubscriptionHandleBase* Handle) -{ - if (!Handle) - { - UE_LOG(LogTemp, Error, TEXT("StartSubscription called with null handle")); - return; - } - - if (Handle->QuerySqls.Num() == 0) - { - UE_LOG(LogTemp, Error, TEXT("StartSubscription called with empty query list")); - return; - } - - const int32 QueryId = GetNextSubscriptionId(); - Handle->QueryId = QueryId; - Handle->ConnInternal = this; - ActiveSubscriptions.Add(QueryId, Handle); - - FSubscribeMultiType SubMsg; - SubMsg.QueryStrings = Handle->QuerySqls; - SubMsg.RequestId = GetNextRequestId(); - SubMsg.QueryId.Id = QueryId; - - FClientMessageType Msg = FClientMessageType::SubscribeMulti(SubMsg); - TArray Data = UE::SpacetimeDB::Serialize(Msg); - SendRawMessage(Data); -} - -void UDbConnectionBase::UnsubscribeInternal(USubscriptionHandleBase* Handle) -{ - if (!Handle || Handle->bEnded) - { - return; - } - - const int32 QueryId = Handle->QueryId; - FUnsubscribeMultiType MsgData; - MsgData.RequestId = GetNextRequestId(); - MsgData.QueryId.Id = QueryId; - - FClientMessageType Msg = FClientMessageType::UnsubscribeMulti(MsgData); - TArray Data = UE::SpacetimeDB::Serialize(Msg); - SendRawMessage(Data); -} - -void UDbConnectionBase::InternalCallReducer(const FString& Reducer, TArray Args, USetReducerFlagsBase* Flags) -{ - if (!WebSocket || !WebSocket->IsConnected()) - { - UE_LOG(LogTemp, Error, TEXT("Cannot call reducer, not connected to server!")); - return; - } - - uint8 FlagToUse = 0; // Default to FullUpdate - if (Flags && Flags->FlagMap.Contains(Reducer)) - { - //Select flag if set by user - ECallReducerFlags FlagFound = *Flags->FlagMap.Find(Reducer); - FlagToUse = static_cast(FlagFound); - } - - FCallReducerType MsgData; - MsgData.Reducer = Reducer; - MsgData.Args = Args; - MsgData.RequestId = GetNextRequestId(); - MsgData.Flags = FlagToUse; - - FClientMessageType Msg = FClientMessageType::CallReducer(MsgData); - TArray Data = UE::SpacetimeDB::Serialize(Msg); - SendRawMessage(Data); -} - -void UDbConnectionBase::InternalCallProcedure(const FString& ProcedureName, TArray Args, const FOnProcedureCompleteDelegate& Callback) -{ - if (!WebSocket || !WebSocket->IsConnected()) - { - UE_LOG(LogTemp, Error, TEXT("Cannot call proceduer, not connected to server!")); - return; - } - FCallProcedureType MsgData; - MsgData.Procedure = ProcedureName; - MsgData.Args = Args; - MsgData.RequestId = ProcedureCallbacks->RegisterCallback(Callback); - MsgData.Flags = static_cast(EProcedureFlags::Default); - - FClientMessageType Msg = FClientMessageType::CallProcedure(MsgData); - TArray Data = UE::SpacetimeDB::Serialize(Msg); - SendRawMessage(Data); -} - -void UDbConnectionBase::ApplyRegisteredTableUpdates(const FDatabaseUpdateType& Update, void* Context) -{ - // Ensure we have a valid context for the update - TArray> Handlers; - for (const FTableUpdateType& TableUpdate : Update.Tables) - { - TSharedPtr Handler; - { - // Find the handler for this table update - FScopeLock Lock(&RegisteredTablesMutex); - if (TSharedPtr* Found = RegisteredTables.Find(TableUpdate.TableName)) - { - Handler = *Found; - } - } - if (Handler.IsValid()) - { - // Update the cache for the handler with the table update and context - Handler->UpdateCache(this, TableUpdate, Context); - Handlers.Add(Handler); - } - } - - for (TSharedPtr& Handler : Handlers) - { - // Broadcast the diff for each handler - Handler->BroadcastDiff(this, Context); - } +#include "Connection/DbConnectionBase.h" +#include "Connection/DbConnectionBuilder.h" +#include "Connection/Credentials.h" +#include "Connection/LogCategory.h" +#include "ModuleBindings/Types/ClientMessageType.g.h" +#include "ModuleBindings/Types/SubscribeMultiType.g.h" +#include "ModuleBindings/Types/UnsubscribeMultiType.g.h" +#include "ModuleBindings/Types/SubscribeMultiAppliedType.g.h" +#include "ModuleBindings/Types/UnsubscribeMultiAppliedType.g.h" +#include "ModuleBindings/Types/SubscriptionErrorType.g.h" +#include "ModuleBindings/Types/DatabaseUpdateType.g.h" +#include "ModuleBindings/Types/CompressableQueryUpdateType.g.h" +#include "Misc/Compression.h" +#include "Misc/ScopeLock.h" +#include "Async/Async.h" +#include "BSATN/UEBSATNHelpers.h" +#include "Connection/ProcedureFlags.h" + +UDbConnectionBase::UDbConnectionBase(const FObjectInitializer& ObjectInitializer) + : Super(ObjectInitializer) +{ + NextRequestId = 1; + NextSubscriptionId = 1; + ProcedureCallbacks = CreateDefaultSubobject(TEXT("ProcedureCallbacks")); +} + +void UDbConnectionBase::Disconnect() +{ + if (WebSocket) + { + WebSocket->Disconnect(); + } +} + +bool UDbConnectionBase::IsActive() const +{ + return WebSocket && WebSocket->IsConnected(); +} + + +bool UDbConnectionBase::TryGetIdentity(FSpacetimeDBIdentity& OutIdentity) const +{ + if (bIsIdentitySet) + { + OutIdentity = Identity; + return true; + } + + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("TryGetIdentity called before identity was set")); + return false; +} + +FSpacetimeDBConnectionId UDbConnectionBase::GetConnectionId() const +{ + return ConnectionId; +} + +bool UDbConnectionBase::SendRawMessage(const FString& Message) +{ + return WebSocket && WebSocket->SendMessage(Message); +} + +bool UDbConnectionBase::SendRawMessage(const TArray& Message) +{ + return WebSocket && WebSocket->SendMessage(Message); +} + +USubscriptionBuilderBase* UDbConnectionBase::SubscriptionBuilderBase() +{ + return NewObject(); +} + +void UDbConnectionBase::HandleWSError(const FString& Error) +{ + if (OnConnectErrorDelegate.IsBound()) + { + OnConnectErrorDelegate.Execute(Error); + } +} + +void UDbConnectionBase::HandleWSClosed(int32 /*StatusCode*/, const FString& Reason, bool /*bWasClean*/) +{ + if (OnDisconnectBaseDelegate.IsBound()) + { + OnDisconnectBaseDelegate.Execute(this, Reason); + } +} + +void UDbConnectionBase::HandleWSBinaryMessage(const TArray& Message) +{ + //tag for arrival order + const int32 Id = NextPreprocessId.GetValue(); + NextPreprocessId.Increment(); + + //do expensive work off-thread + TWeakObjectPtr WeakThis(this); + Async(EAsyncExecution::Thread, [WeakThis, Message, Id]() + { + if (!WeakThis.IsValid()) + { + return; + } + UDbConnectionBase* This = WeakThis.Get(); + + //parse the message, decompress if needed + FServerMessageType Parsed = This->PreProcessMessage(Message); + + //queue: re-order buffer + TArray Ready; + { + FScopeLock Lock(&This->PreprocessMutex); + // Move the parsed message into the map to avoid copying + This->PreprocessedMessages.Add(Id, MoveTemp(Parsed)); + //check if we can release any messages in order + while (This->PreprocessedMessages.Contains(This->NextReleaseId)) + { + Ready.Add(This->PreprocessedMessages.FindAndRemoveChecked(This->NextReleaseId)); + ++This->NextReleaseId; + } + } + //if we have any ready messages, append them to the pending messages list that is processed in Tick + if (Ready.Num() > 0) + { + FScopeLock Lock(&This->PendingMessagesMutex); + This->PendingMessages.Append(MoveTemp(Ready)); + } + }); +} + +void UDbConnectionBase::FrameTick() +{ + TArray Local; + { + FScopeLock Lock(&PendingMessagesMutex); + if (PendingMessages.Num() == 0) + { + //nothing to process, return early + return; + } + //move pending messages to local array for processing + Local = MoveTemp(PendingMessages); + PendingMessages.Empty(); + } + + //process all messages in the local array + for (const FServerMessageType& Msg : Local) + { + //process the message, this will call DbUpdate or trigger subscription events as needed + ProcessServerMessage(Msg); + } +} +void UDbConnectionBase::Tick(float DeltaTime) +{ + if (bIsAutoTicking) + { + FrameTick(); + } +} + +TStatId UDbConnectionBase::GetStatId() const +{ + // This is used by the engine to track tickables, we return a unique stat ID for this class + RETURN_QUICK_DECLARE_CYCLE_STAT(UMyTickableObject, STATGROUP_Tickables); +} + +bool UDbConnectionBase::IsTickable() const +{ + return bIsAutoTicking; +} + +bool UDbConnectionBase::IsTickableInEditor() const +{ + return bIsAutoTicking; +} + + +void UDbConnectionBase::ProcessServerMessage(const FServerMessageType& Message) +{ + bool bIsValid = false; + switch (Message.Tag) + { + case EServerMessageTag::InitialSubscription: + { + //@Note: This is a legacy tag, used implemented in current server version + break; + } + case EServerMessageTag::TransactionUpdate: + { + // Process a transaction update message + const FTransactionUpdateType Payload = Message.GetAsTransactionUpdate(); + + // Create a status object based on the transaction status + FSpacetimeDBStatus StatusObj; + bool bSuccess = false; + FString ErrorMessage; + if (Payload.Status.IsCommitted()) + { + bSuccess = true; + StatusObj = FSpacetimeDBStatus::Committed(FSpacetimeDBUnit()); + } + else if (Payload.Status.IsFailed()) + { + ErrorMessage = Payload.Status.GetAsFailed(); + StatusObj = FSpacetimeDBStatus::Failed(ErrorMessage); + } + else if (Payload.Status.IsOutOfEnergy()) + { + Payload.Status.GetAsOutOfEnergy(); + StatusObj = FSpacetimeDBStatus::OutOfEnergy(FSpacetimeDBUnit()); + ErrorMessage = TEXT("Out of energy"); + } + + // Process the transaction update and create a reducer event + FReducerEvent RedEvent; + RedEvent.Timestamp = Payload.Timestamp; + RedEvent.Status = StatusObj; + RedEvent.CallerIdentity = Payload.CallerIdentity; + RedEvent.CallerConnectionId = Payload.CallerConnectionId; + RedEvent.EnergyConsumed = Payload.EnergyQuantaUsed; + RedEvent.ReducerCall = Payload.ReducerCall; + + // If the status is committed, we update the database + if (bSuccess) + { + DbUpdate(Payload.Status.GetAsCommitted(), FSpacetimeDBEvent::Reducer(RedEvent)); // Update table and trigger insert/update/delete + ReducerEvent(RedEvent); // Trigger the reducer event + } + else + { + ReducerEvent(RedEvent); // Trigger the reducer event + ReducerEventFailed(RedEvent, ErrorMessage); + } + break; + } + case EServerMessageTag::TransactionUpdateLight: + { + // Process a light transaction update message + const FTransactionUpdateLightType Payload = Message.GetAsTransactionUpdateLight(); + + //@TODO: Implement light update fully + DbUpdate(Payload.Update, FSpacetimeDBEvent::UnknownTransaction(FSpacetimeDBUnit())); + + break; + } + case EServerMessageTag::IdentityToken: + { + // Process an identity token message + const FIdentityTokenType Payload = Message.GetAsIdentityToken(); + + Token = Payload.Token; + UCredentials::SaveToken(Token); + Identity = Payload.Identity; + bIsIdentitySet = true; + UE_LOG(LogSpacetimeDb_Connection, Verbose, TEXT("IdentityToken: Identity set to: %s"), *Identity.ToHex()); + ConnectionId = Payload.ConnectionId; + if (OnConnectBaseDelegate.IsBound()) + { + OnConnectBaseDelegate.Execute(this, Identity, Token); + } + break; + } + case EServerMessageTag::OneOffQueryResponse: + { + //@Note: Not implemented in Rust version, skip for now here aswell + break; + } + case EServerMessageTag::SubscribeApplied: + { + //@Note: This is a legacy tag, not implemented in current server version + break; + } + case EServerMessageTag::UnsubscribeApplied: + { + //@Note: This is a legacy tag, not implemented in current server version + break; + } + case EServerMessageTag::SubscriptionError: + { + // Process a subscription error message + const FSubscriptionErrorType Payload = Message.GetAsSubscriptionError(); + if (TObjectPtr Handle = *ActiveSubscriptions.Find(Payload.QueryId.Value)) + { + if (!Handle) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("SubscriptionError: Null handle for QueryId %u. Error: %s"), + Payload.QueryId.Value, + *Payload.Error); + return; + } + FErrorContextBase Ctx; Ctx.Error = Payload.Error; + Handle->TriggerError(Ctx); + ActiveSubscriptions.Remove(Payload.QueryId.Value); + } + break; + } + case EServerMessageTag::SubscribeMultiApplied: + { + // Process a multi-subscription applied message + const FSubscribeMultiAppliedType Payload = Message.GetAsSubscribeMultiApplied(); + // Update the database with the subscription applied event + DbUpdate(Payload.Update, FSpacetimeDBEvent::SubscribeApplied(FSpacetimeDBUnit())); + + if (TObjectPtr Handle = *ActiveSubscriptions.Find(Payload.QueryId.Id)) + { + if (!Handle) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("SubscriptionError: Null handle for QueryId %u."), Payload.QueryId.Id); + return; + } + FSubscriptionEventContextBase Ctx; Ctx.Event = FSpacetimeDBEvent::SubscribeApplied(FSpacetimeDBUnit()); + Handle->TriggerApplied(Ctx); + } + + break; + } + case EServerMessageTag::UnsubscribeMultiApplied: + { + // Process a multi-unsubscription applied message + const FUnsubscribeMultiAppliedType Payload = Message.GetAsUnsubscribeMultiApplied(); + + // Update the database with the unsubscription applied event + DbUpdate(Payload.Update, FSpacetimeDBEvent::UnsubscribeApplied(FSpacetimeDBUnit())); + if (TObjectPtr Handle = *ActiveSubscriptions.Find(Payload.QueryId.Id)) + { + if (!Handle) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("UnsubscribeMultiApplied: Null handle for QueryId %u."), Payload.QueryId.Id); + return; + } + Handle->bEnded = true; + Handle->bActive = false; + Handle->bUnsubscribeCalled = true; + FSubscriptionEventContextBase Ctx; Ctx.Event = FSpacetimeDBEvent::UnsubscribeApplied(FSpacetimeDBUnit()); + if (Handle->EndDelegate.IsBound()) + { + Handle->EndDelegate.Execute(Ctx); + } + ActiveSubscriptions.Remove(Payload.QueryId.Id); + } + break; + } + case EServerMessageTag::ProcedureResult: + { + const FProcedureResultType Payload = Message.GetAsProcedureResult(); + FProcedureEvent ProcEvent; + ProcEvent.Status = Payload.Status; + ProcEvent.Timestamp = Payload.Timestamp; + ProcEvent.TotalHostExecutionDuration = Payload.TotalHostExecutionDuration; + ProcEvent.Success = ProcEvent.Status.IsReturned(); + TArray PayloadData; + FString ErrorMessage = ""; + if (ProcEvent.Success) + PayloadData = ProcEvent.Status.GetAsReturned(); + if (Payload.Status.IsOutOfEnergy()) + { + ErrorMessage = TEXT("Out of energy"); + } + else if (Payload.Status.IsInternalError()) + { + ErrorMessage = Payload.Status.GetAsInternalError(); + } + + ProcedureCallbacks->ResolveCallback(Payload.RequestId, FSpacetimeDBEvent::Procedure(ProcEvent), PayloadData, ProcEvent.Success); + if (!ProcEvent.Success) + { + ProcedureEventFailed(ProcEvent, ErrorMessage); + } + break; + } + default: + // Unknown tag - bail out + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("Unknown server-message tag")); + break; + } +} + +bool UDbConnectionBase::DecompressBrotli(const TArray& InData, TArray& OutData) +{ + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Brotli decompression unavilable")); + return false; +} + +bool UDbConnectionBase::DecompressGzip(const TArray& InData, TArray& OutData) +{ + if (InData.Num() < 4) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Gzip data too small")); + return false; + } + + // Gzip data ends with 4 bytes indicating the uncompressed size + const uint8* SizePtr = InData.GetData() + InData.Num() - 4; + uint32 OutSize = SizePtr[0] | (SizePtr[1] << 8) | (SizePtr[2] << 16) | (SizePtr[3] << 24); + + // Validate the output size + OutData.SetNumUninitialized(OutSize); + // Attempt to decompress the Gzip data + if (!FCompression::UncompressMemory(NAME_Gzip, OutData.GetData(), OutSize, InData.GetData(), InData.Num())) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Gzip decompression failed")); + return false; + } + + OutData.SetNum(OutSize); + return true; +} + +bool UDbConnectionBase::DecompressPayload(ECompressableQueryUpdateTag Variant, const TArray& In, TArray& Out) +{ + switch (Variant) + { + case ECompressableQueryUpdateTag::Uncompressed: + // No compression, just copy the data + Out = In; + return true; + case ECompressableQueryUpdateTag::Brotli: + return DecompressBrotli(In, Out); + case ECompressableQueryUpdateTag::Gzip: + return DecompressGzip(In, Out); + default: + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Unknown compression variant")); + return false; + } +} + +void UDbConnectionBase::PreProcessDatabaseUpdate(const FDatabaseUpdateType& Update) +{ + for (const FTableUpdateType& TableUpdate : Update.Tables) + { + TArray UncompressedCQUs; + for (const FCompressableQueryUpdateType& CQU : TableUpdate.Updates) + { + + // Uncompress the CQU based on its tag + FQueryUpdateType UncompressedUpdate; + switch (CQU.Tag) + { + case ECompressableQueryUpdateTag::Uncompressed: + UncompressedUpdate = CQU.GetAsUncompressed(); + break; + case ECompressableQueryUpdateTag::Brotli: + { + TArray Data = CQU.GetAsBrotli(); + TArray Dec; + if (DecompressBrotli(Data, Dec)) + { + //@Note: This will never trigger until Brotli decompression is implemented + UncompressedUpdate = UE::SpacetimeDB::Deserialize(Dec); + } + break; + } + case ECompressableQueryUpdateTag::Gzip: + { + TArray Data = CQU.GetAsGzip(); + TArray Dec; + if (DecompressGzip(Data, Dec)) + { + UncompressedUpdate = UE::SpacetimeDB::Deserialize(Dec); + } + break; + } + default: + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Unknown compression variant in CQU")); + break; + } + UncompressedCQUs.Add(FCompressableQueryUpdateType::Uncompressed(UncompressedUpdate)); + UE_LOG(LogSpacetimeDb_Connection, Verbose, TEXT("Table %s Inserts:%d Deletes:%d"), *TableUpdate.TableName, UncompressedUpdate.Inserts.RowsData.Num(), UncompressedUpdate.Deletes.RowsData.Num()); + } + + // After ensuring all updates are uncompressed, attempt to deserialize rows + TSharedPtr Deserializer; + { + // Find the deserializer for this table + FScopeLock Lock(&TableDeserializersMutex); + if (TSharedPtr* Found = TableDeserializers.Find(TableUpdate.TableName)) + { + // If found, use the deserializer + Deserializer = *Found; + } + else + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("No deserializer found for table %s"), *TableUpdate.TableName); + } + } + if (Deserializer) + { + // Preprocess the table data using the deserializer + TSharedPtr Data = Deserializer->PreProcess(UncompressedCQUs, TableUpdate.TableName); + if (Data.IsValid()) + { + // Store the preprocessed data in the mutex-protected map + FScopeLock Lock(&PreprocessedDataMutex); + FPreprocessedTableKey Key(TableUpdate.TableId, TableUpdate.TableName); + TArray>& Queue = PreprocessedTableData.FindOrAdd(Key); + Queue.Add(Data); + } + } + else + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Skipping table %s updates due to missing deserializer"), *TableUpdate.TableName); + } + } +} + +FServerMessageType UDbConnectionBase::PreProcessMessage(const TArray& Message) +{ + if (Message.Num() == 0) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Empty message recived from server, ignored")); + return FServerMessageType{}; + } + // Check if the first byte is a valid compression tag + ECompressableQueryUpdateTag Compression = static_cast(Message[0]); + TArray CompressedPayload; + CompressedPayload.Append(Message.GetData() + 1, Message.Num() - 1); + + // Decompress the payload based on the compression tag + TArray Decompressed; + if (!DecompressPayload(Compression, CompressedPayload, Decompressed)) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Failed to decompress incoming message")); + return FServerMessageType{}; + } + + // Deserialize the decompressed data into a UServerMessageType object + FServerMessageType Parsed = UE::SpacetimeDB::Deserialize(Decompressed); + + // Process it based on its tag. Messages containing rows will be deserialized into rows based on registered type and table name. + bool bValid = false; + switch (Parsed.Tag) + { + case EServerMessageTag::InitialSubscription: + { + const FInitialSubscriptionType Payload = Parsed.GetAsInitialSubscription(); + // PreProcess the initial subscription payload + PreProcessDatabaseUpdate(Payload.DatabaseUpdate); + break; + } + case EServerMessageTag::TransactionUpdate: + { + + const FTransactionUpdateType Payload = Parsed.GetAsTransactionUpdate(); + if (Payload.Status.IsCommitted()) + { + // PreProcess the database update with the committed status + PreProcessDatabaseUpdate(Payload.Status.GetAsCommitted()); + } + break; + } + case EServerMessageTag::TransactionUpdateLight: + { + //@Note: Light tag in not implemented as an option in connection builder, this will never trigger but we keep this for future compatibility + const FTransactionUpdateLightType Payload = Parsed.GetAsTransactionUpdateLight(); + // PreProcess the light transaction update + PreProcessDatabaseUpdate(Payload.Update); + break; + } + case EServerMessageTag::SubscribeMultiApplied: + { + const FSubscribeMultiAppliedType Payload = Parsed.GetAsSubscribeMultiApplied(); + PreProcessDatabaseUpdate(Payload.Update); + break; + } + case EServerMessageTag::UnsubscribeMultiApplied: + { + const FUnsubscribeMultiAppliedType Payload = Parsed.GetAsUnsubscribeMultiApplied(); + PreProcessDatabaseUpdate(Payload.Update); + break; + } + default: + break; + } + return Parsed; +} + + +int32 UDbConnectionBase::GetNextRequestId() +{ + return NextRequestId++; +} + +int32 UDbConnectionBase::GetNextSubscriptionId() +{ + return NextSubscriptionId++; +} + +void UDbConnectionBase::StartSubscription(USubscriptionHandleBase* Handle) +{ + if (!Handle) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("StartSubscription called with null handle")); + return; + } + + if (Handle->QuerySqls.Num() == 0) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("StartSubscription called with empty query list")); + return; + } + + const int32 QueryId = GetNextSubscriptionId(); + Handle->QueryId = QueryId; + Handle->ConnInternal = this; + ActiveSubscriptions.Add(QueryId, Handle); + + FSubscribeMultiType SubMsg; + SubMsg.QueryStrings = Handle->QuerySqls; + SubMsg.RequestId = GetNextRequestId(); + SubMsg.QueryId.Id = QueryId; + + FClientMessageType Msg = FClientMessageType::SubscribeMulti(SubMsg); + TArray Data = UE::SpacetimeDB::Serialize(Msg); + SendRawMessage(Data); +} + +void UDbConnectionBase::UnsubscribeInternal(USubscriptionHandleBase* Handle) +{ + if (!Handle || Handle->bEnded) + { + return; + } + + const int32 QueryId = Handle->QueryId; + FUnsubscribeMultiType MsgData; + MsgData.RequestId = GetNextRequestId(); + MsgData.QueryId.Id = QueryId; + + FClientMessageType Msg = FClientMessageType::UnsubscribeMulti(MsgData); + TArray Data = UE::SpacetimeDB::Serialize(Msg); + SendRawMessage(Data); +} + +void UDbConnectionBase::InternalCallReducer(const FString& Reducer, TArray Args, USetReducerFlagsBase* Flags) +{ + if (!WebSocket || !WebSocket->IsConnected()) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Cannot call reducer, not connected to server!")); + return; + } + + uint8 FlagToUse = 0; // Default to FullUpdate + if (Flags && Flags->FlagMap.Contains(Reducer)) + { + //Select flag if set by user + ECallReducerFlags FlagFound = *Flags->FlagMap.Find(Reducer); + FlagToUse = static_cast(FlagFound); + } + + FCallReducerType MsgData; + MsgData.Reducer = Reducer; + MsgData.Args = Args; + MsgData.RequestId = GetNextRequestId(); + MsgData.Flags = FlagToUse; + + FClientMessageType Msg = FClientMessageType::CallReducer(MsgData); + TArray Data = UE::SpacetimeDB::Serialize(Msg); + SendRawMessage(Data); +} + +void UDbConnectionBase::InternalCallProcedure(const FString& ProcedureName, TArray Args, const FOnProcedureCompleteDelegate& Callback) +{ + if (!WebSocket || !WebSocket->IsConnected()) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("Cannot call proceduer, not connected to server!")); + return; + } + FCallProcedureType MsgData; + MsgData.Procedure = ProcedureName; + MsgData.Args = Args; + MsgData.RequestId = ProcedureCallbacks->RegisterCallback(Callback); + MsgData.Flags = static_cast(EProcedureFlags::Default); + + FClientMessageType Msg = FClientMessageType::CallProcedure(MsgData); + TArray Data = UE::SpacetimeDB::Serialize(Msg); + SendRawMessage(Data); +} + +void UDbConnectionBase::ApplyRegisteredTableUpdates(const FDatabaseUpdateType& Update, void* Context) +{ + // Ensure we have a valid context for the update + TArray> Handlers; + for (const FTableUpdateType& TableUpdate : Update.Tables) + { + TSharedPtr Handler; + { + // Find the handler for this table update + FScopeLock Lock(&RegisteredTablesMutex); + if (TSharedPtr* Found = RegisteredTables.Find(TableUpdate.TableName)) + { + Handler = *Found; + } + } + if (Handler.IsValid()) + { + // Update the cache for the handler with the table update and context + Handler->UpdateCache(this, TableUpdate, Context); + Handlers.Add(Handler); + } + } + + for (TSharedPtr& Handler : Handlers) + { + // Broadcast the diff for each handler + Handler->BroadcastDiff(this, Context); + } } \ No newline at end of file diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBuilder.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBuilder.cpp index e59aaec7dd2..f933dbe7f74 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBuilder.cpp +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/DbConnectionBuilder.cpp @@ -1,148 +1,148 @@ -#include "Connection/DbConnectionBuilder.h" -#include "Connection/Websocket.h" -#include "Connection/DbConnectionBase.h" - - -UDbConnectionBuilderBase* UDbConnectionBuilderBase::WithUriBase(const FString& InUri) -{ - // Check if the URI contains "localhost:" and replace it with adress - if (InUri.IsEmpty()) - { - UE_LOG(LogTemp, Warning, TEXT("WithUriBase called with empty URI, not allowed")); - return this; - } - if (InUri.Contains("localhost:")) - { - FString FixedUri = InUri.Replace(TEXT("localhost"), TEXT("127.0.0.1"), ESearchCase::IgnoreCase); - Uri = FixedUri; - } - else - { - Uri = InUri; - } - return this; -} - - -UDbConnectionBuilderBase* UDbConnectionBuilderBase::WithModuleNameBase(const FString& InName) -{ - if (InName.IsEmpty()) - { - UE_LOG(LogTemp, Warning, TEXT("WithModuleNameBase called with empty module name, not allowd")); - } - ModuleName = InName; - return this; -} - - -UDbConnectionBuilderBase* UDbConnectionBuilderBase::WithTokenBase(const FString& InToken) -{ - Token = InToken; - return this; -} - -UDbConnectionBuilderBase* UDbConnectionBuilderBase::WithCompressionBase(const ESpacetimeDBCompression& InCompression) -{ - if (InCompression == ESpacetimeDBCompression::Brotli) - { - UE_LOG(LogTemp, Warning, TEXT("Brotli compression is not available in this version of SDK. Defaulting to Gzip.")); - Compression = ESpacetimeDBCompression::Gzip; - } - else - { - Compression = InCompression; - } - bCompressionSet = true; - return this; -} - -UDbConnectionBuilderBase* UDbConnectionBuilderBase::OnConnectBase(FOnConnectBaseDelegate Callback) -{ - OnConnectCallback = Callback; - return this; -} - -UDbConnectionBuilderBase* UDbConnectionBuilderBase::OnConnectErrorBase(FOnConnectErrorDelegate Callback) -{ - OnConnectErrorCallback = Callback; - return this; -} - -UDbConnectionBuilderBase* UDbConnectionBuilderBase::OnDisconnectBase(FOnDisconnectBaseDelegate Callback) -{ - OnDisconnectCallback = Callback; - return this; -} - -UDbConnectionBase* UDbConnectionBuilderBase::BuildConnection(UDbConnectionBase* Connection) -{ - - if (!Connection) - { - UE_LOG(LogTemp, Error, TEXT("BuildConnection called with null connection object")); - return nullptr; - } - - if (Uri.IsEmpty() || ModuleName.IsEmpty()) - { - UE_LOG(LogTemp, Error, TEXT("BuildConnection missing required Uri or ModuleName")); - return nullptr; - } - - FString WorkUri = Uri; - WorkUri.TrimStartAndEndInline(); - - // Normalize scheme: https->wss, http->ws, default to ws if none provided. - if (WorkUri.StartsWith(TEXT("https://"), ESearchCase::IgnoreCase)) - { - WorkUri = TEXT("wss://") + WorkUri.Mid(8); - } - else if (WorkUri.StartsWith(TEXT("http://"), ESearchCase::IgnoreCase)) - { - WorkUri = TEXT("ws://") + WorkUri.Mid(7); - } - else if (!WorkUri.StartsWith(TEXT("ws://"), ESearchCase::IgnoreCase) && - !WorkUri.StartsWith(TEXT("wss://"), ESearchCase::IgnoreCase)) - { - WorkUri = TEXT("ws://") + WorkUri; - } - - if (WorkUri.EndsWith(TEXT("/"))) - { - WorkUri.LeftChopInline(1); - } - - Connection->Uri = WorkUri; - Connection->ModuleName = ModuleName; - Connection->Token = Token; - Connection->OnConnectBaseDelegate = OnConnectCallback; - Connection->OnConnectErrorDelegate = OnConnectErrorCallback; - Connection->OnDisconnectBaseDelegate = OnDisconnectCallback; - - Connection->WebSocket = NewObject(Connection); - - //Default to Gzip compression if not set - if (!bCompressionSet) - { - Compression = ESpacetimeDBCompression::Gzip; - } - - const UEnum* CompressionEnum = StaticEnum(); - const FString CompressionName = CompressionEnum->GetNameStringByValue(static_cast(Compression)); - - // Construct the WebSocket URL using the provided URI, module name, and compression type - FString WebSocketUrl = FString::Printf(TEXT("%s/v1/database/%s/subscribe?compression=%s"), - *WorkUri, - *ModuleName, - *CompressionName); - - Connection->WebSocket->OnConnectionError.AddDynamic(Connection, &UDbConnectionBase::HandleWSError); - Connection->WebSocket->OnClosed.AddDynamic(Connection, &UDbConnectionBase::HandleWSClosed); - Connection->WebSocket->OnBinaryMessageReceived.AddDynamic(Connection, &UDbConnectionBase::HandleWSBinaryMessage); - // Set the initialization token for the WebSocket connection - Connection->WebSocket->SetInitToken(Token); - // Connect the WebSocket to the constructed URL - Connection->WebSocket->Connect(WebSocketUrl); - - return Connection; -} +#include "Connection/DbConnectionBuilder.h" +#include "Connection/Websocket.h" +#include "Connection/DbConnectionBase.h" +#include "Connection/LogCategory.h" + +UDbConnectionBuilderBase* UDbConnectionBuilderBase::WithUriBase(const FString& InUri) +{ + // Check if the URI contains "localhost:" and replace it with adress + if (InUri.IsEmpty()) + { + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("WithUriBase called with empty URI, not allowed")); + return this; + } + if (InUri.Contains("localhost:")) + { + FString FixedUri = InUri.Replace(TEXT("localhost"), TEXT("127.0.0.1"), ESearchCase::IgnoreCase); + Uri = FixedUri; + } + else + { + Uri = InUri; + } + return this; +} + + +UDbConnectionBuilderBase* UDbConnectionBuilderBase::WithModuleNameBase(const FString& InName) +{ + if (InName.IsEmpty()) + { + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("WithModuleNameBase called with empty module name, not allowd")); + } + ModuleName = InName; + return this; +} + + +UDbConnectionBuilderBase* UDbConnectionBuilderBase::WithTokenBase(const FString& InToken) +{ + Token = InToken; + return this; +} + +UDbConnectionBuilderBase* UDbConnectionBuilderBase::WithCompressionBase(const ESpacetimeDBCompression& InCompression) +{ + if (InCompression == ESpacetimeDBCompression::Brotli) + { + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("Brotli compression is not available in this version of SDK. Defaulting to Gzip.")); + Compression = ESpacetimeDBCompression::Gzip; + } + else + { + Compression = InCompression; + } + bCompressionSet = true; + return this; +} + +UDbConnectionBuilderBase* UDbConnectionBuilderBase::OnConnectBase(FOnConnectBaseDelegate Callback) +{ + OnConnectCallback = Callback; + return this; +} + +UDbConnectionBuilderBase* UDbConnectionBuilderBase::OnConnectErrorBase(FOnConnectErrorDelegate Callback) +{ + OnConnectErrorCallback = Callback; + return this; +} + +UDbConnectionBuilderBase* UDbConnectionBuilderBase::OnDisconnectBase(FOnDisconnectBaseDelegate Callback) +{ + OnDisconnectCallback = Callback; + return this; +} + +UDbConnectionBase* UDbConnectionBuilderBase::BuildConnection(UDbConnectionBase* Connection) +{ + + if (!Connection) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("BuildConnection called with null connection object")); + return nullptr; + } + + if (Uri.IsEmpty() || ModuleName.IsEmpty()) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("BuildConnection missing required Uri or ModuleName")); + return nullptr; + } + + FString WorkUri = Uri; + WorkUri.TrimStartAndEndInline(); + + // Normalize scheme: https->wss, http->ws, default to ws if none provided. + if (WorkUri.StartsWith(TEXT("https://"), ESearchCase::IgnoreCase)) + { + WorkUri = TEXT("wss://") + WorkUri.Mid(8); + } + else if (WorkUri.StartsWith(TEXT("http://"), ESearchCase::IgnoreCase)) + { + WorkUri = TEXT("ws://") + WorkUri.Mid(7); + } + else if (!WorkUri.StartsWith(TEXT("ws://"), ESearchCase::IgnoreCase) && + !WorkUri.StartsWith(TEXT("wss://"), ESearchCase::IgnoreCase)) + { + WorkUri = TEXT("ws://") + WorkUri; + } + + if (WorkUri.EndsWith(TEXT("/"))) + { + WorkUri.LeftChopInline(1); + } + + Connection->Uri = WorkUri; + Connection->ModuleName = ModuleName; + Connection->Token = Token; + Connection->OnConnectBaseDelegate = OnConnectCallback; + Connection->OnConnectErrorDelegate = OnConnectErrorCallback; + Connection->OnDisconnectBaseDelegate = OnDisconnectCallback; + + Connection->WebSocket = NewObject(Connection); + + //Default to Gzip compression if not set + if (!bCompressionSet) + { + Compression = ESpacetimeDBCompression::Gzip; + } + + const UEnum* CompressionEnum = StaticEnum(); + const FString CompressionName = CompressionEnum->GetNameStringByValue(static_cast(Compression)); + + // Construct the WebSocket URL using the provided URI, module name, and compression type + FString WebSocketUrl = FString::Printf(TEXT("%s/v1/database/%s/subscribe?compression=%s"), + *WorkUri, + *ModuleName, + *CompressionName); + + Connection->WebSocket->OnConnectionError.AddDynamic(Connection, &UDbConnectionBase::HandleWSError); + Connection->WebSocket->OnClosed.AddDynamic(Connection, &UDbConnectionBase::HandleWSClosed); + Connection->WebSocket->OnBinaryMessageReceived.AddDynamic(Connection, &UDbConnectionBase::HandleWSBinaryMessage); + // Set the initialization token for the WebSocket connection + Connection->WebSocket->SetInitToken(Token); + // Connect the WebSocket to the constructed URL + Connection->WebSocket->Connect(WebSocketUrl); + + return Connection; +} diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/LogCategory.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/LogCategory.cpp new file mode 100644 index 00000000000..d114ea32a2e --- /dev/null +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/LogCategory.cpp @@ -0,0 +1,3 @@ +#include "Connection/LogCategory.h" + +DEFINE_LOG_CATEGORY(LogSpacetimeDb_Connection); \ No newline at end of file diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Subscription.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Subscription.cpp index a9de0f66edd..07ab1e5b32a 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Subscription.cpp +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Subscription.cpp @@ -1,110 +1,111 @@ -#include "Connection/Subscription.h" -#include "Connection/DbConnectionBase.h" - -USubscriptionHandleBase::USubscriptionHandleBase() {} - -void USubscriptionHandleBase::Unsubscribe() -{ - if (bEnded ) - { - UE_LOG(LogTemp, Warning, TEXT("USubscriptionHandleBase::Unsubscribe called on an already ended handle. Not allowed")); - return; - } - if (bUnsubscribeCalled) - { - UE_LOG(LogTemp, Warning, TEXT("USubscriptionHandleBase::Unsubscribe called multiple times for the same handle. Not allowed")); - return; - } - - bUnsubscribeCalled = true; - - if (ConnInternal) - { - // If we have a connection, we will unsubscribe from it - ConnInternal->UnsubscribeInternal(this); - } - else - { - // If we don't have a connection, we just end the subscription - bEnded = true; - bActive = false; - if (EndDelegate.IsBound()) - { - FSubscriptionEventContextBase Ctx; - EndDelegate.Execute(Ctx); - } - } -} - -void USubscriptionHandleBase::UnsubscribeThen(FSubscriptionEventDelegate OnEnd) -{ - // If we have a connection, we will unsubscribe from it and call the end delegate when done - EndDelegate = OnEnd; - Unsubscribe(); -} - -void USubscriptionHandleBase::TriggerApplied(const FSubscriptionEventContextBase& Context) -{ - if (bEnded) - { - return; - } - bActive = true; - if (AppliedDelegate.IsBound()) - { - // If the subscription is active, we execute the applied delegate with the context - AppliedDelegate.Execute(Context); - } -} - -void USubscriptionHandleBase::TriggerError(const FErrorContextBase& Context) -{ - if (bEnded) - { - return; - } - bEnded = true; - bActive = false; - if (ErrorDelegate.IsBound()) - { - // If the subscription has an error, we execute the error delegate with the context - ErrorDelegate.Execute(Context); - } -} - -USubscriptionBuilderBase::USubscriptionBuilderBase() {} - -USubscriptionBuilderBase* USubscriptionBuilderBase::OnAppliedBase(FSubscriptionEventDelegate Callback) -{ - AppliedDelegate = Callback; - return this; -} - -USubscriptionBuilderBase* USubscriptionBuilderBase::OnErrorBase(FSubscriptionErrorDelegate Callback) -{ - ErrorDelegate = Callback; - return this; -} - -USubscriptionHandleBase* USubscriptionBuilderBase::SubscribeBase(const TArray& QuerySqls, USubscriptionHandleBase* Handle) -{ - if (!Handle) - { - UE_LOG(LogTemp, Error, TEXT("USubscriptionBuilderBase::SubscribeBase: Handle is null! Returning null handle.")); - return Handle; - } - - if (QuerySqls.Num() == 0) - { - UE_LOG(LogTemp, Warning, TEXT("SubscribeBase called with no query strings")); - } - - Handle->AppliedDelegate = AppliedDelegate; - Handle->ErrorDelegate = ErrorDelegate; - Handle->QuerySqls = QuerySqls; - // Reset delegates so builder can be reused safely - AppliedDelegate.Unbind(); - ErrorDelegate.Unbind(); - Handle->bActive = false; - return Handle; +#include "Connection/Subscription.h" +#include "Connection/DbConnectionBase.h" +#include "Connection/LogCategory.h" + +USubscriptionHandleBase::USubscriptionHandleBase() {} + +void USubscriptionHandleBase::Unsubscribe() +{ + if (bEnded ) + { + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("USubscriptionHandleBase::Unsubscribe called on an already ended handle. Not allowed")); + return; + } + if (bUnsubscribeCalled) + { + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("USubscriptionHandleBase::Unsubscribe called multiple times for the same handle. Not allowed")); + return; + } + + bUnsubscribeCalled = true; + + if (ConnInternal) + { + // If we have a connection, we will unsubscribe from it + ConnInternal->UnsubscribeInternal(this); + } + else + { + // If we don't have a connection, we just end the subscription + bEnded = true; + bActive = false; + if (EndDelegate.IsBound()) + { + FSubscriptionEventContextBase Ctx; + EndDelegate.Execute(Ctx); + } + } +} + +void USubscriptionHandleBase::UnsubscribeThen(FSubscriptionEventDelegate OnEnd) +{ + // If we have a connection, we will unsubscribe from it and call the end delegate when done + EndDelegate = OnEnd; + Unsubscribe(); +} + +void USubscriptionHandleBase::TriggerApplied(const FSubscriptionEventContextBase& Context) +{ + if (bEnded) + { + return; + } + bActive = true; + if (AppliedDelegate.IsBound()) + { + // If the subscription is active, we execute the applied delegate with the context + AppliedDelegate.Execute(Context); + } +} + +void USubscriptionHandleBase::TriggerError(const FErrorContextBase& Context) +{ + if (bEnded) + { + return; + } + bEnded = true; + bActive = false; + if (ErrorDelegate.IsBound()) + { + // If the subscription has an error, we execute the error delegate with the context + ErrorDelegate.Execute(Context); + } +} + +USubscriptionBuilderBase::USubscriptionBuilderBase() {} + +USubscriptionBuilderBase* USubscriptionBuilderBase::OnAppliedBase(FSubscriptionEventDelegate Callback) +{ + AppliedDelegate = Callback; + return this; +} + +USubscriptionBuilderBase* USubscriptionBuilderBase::OnErrorBase(FSubscriptionErrorDelegate Callback) +{ + ErrorDelegate = Callback; + return this; +} + +USubscriptionHandleBase* USubscriptionBuilderBase::SubscribeBase(const TArray& QuerySqls, USubscriptionHandleBase* Handle) +{ + if (!Handle) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("USubscriptionBuilderBase::SubscribeBase: Handle is null! Returning null handle.")); + return Handle; + } + + if (QuerySqls.Num() == 0) + { + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("SubscribeBase called with no query strings")); + } + + Handle->AppliedDelegate = AppliedDelegate; + Handle->ErrorDelegate = ErrorDelegate; + Handle->QuerySqls = QuerySqls; + // Reset delegates so builder can be reused safely + AppliedDelegate.Unbind(); + ErrorDelegate.Unbind(); + Handle->bActive = false; + return Handle; } \ No newline at end of file diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp index 864c4ebba40..9c1f4525f67 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Private/Connection/Websocket.cpp @@ -1,215 +1,215 @@ - -#include "Connection/Websocket.h" -#include "WebSocketsModule.h" // Required for FWebSocketsModule -#include "SpacetimeDbSdk/Public/BSATN/UESpacetimeDB.h" -#include "ModuleBindings/Types/ServerMessageType.g.h" -#include "ModuleBindings/Types/CompressableQueryUpdateType.g.h" -#include "Misc/Compression.h" - -#include "Dom/JsonObject.h" -#include "Serialization/JsonWriter.h" -#include "Serialization/JsonSerializer.h" - -static void LogIdentityTokenHex(const FIdentityTokenType& InToken, const TCHAR* TagName) -{ - // Logs the identity token in a structured format for debugging purposes. - TSharedRef Obj = MakeShared(); - Obj->SetStringField(TEXT("__identity__"), InToken.Identity.ToHex()); - Obj->SetStringField(TEXT("token"), InToken.Token); - Obj->SetStringField(TEXT("__connection_id__"), InToken.ConnectionId.ToHex()); - - FString Json; - TSharedRef> Writer = TJsonWriterFactory<>::Create(&Json); - FJsonSerializer::Serialize(Obj, Writer); - UE_LOG(LogTemp, Log, TEXT("[%s] %s"), TagName, *Json); -} - -UWebsocketManager::UWebsocketManager() -{ - // Ensure the WebSockets module is loaded. - FModuleManager::LoadModuleChecked(TEXT("WebSockets")); -} - -void UWebsocketManager::BeginDestroy() -{ - UE_LOG(LogTemp, Log, TEXT("UWebsocketManager::BeginDestroy: Cleaning up WebSocket.")); - if (!HasAnyFlags(RF_ClassDefaultObject)) - { - Disconnect(); - } - Super::BeginDestroy(); -} - -void UWebsocketManager::Connect(const FString& ServerUrl) -{ - if (IsConnected()) - { - UE_LOG(LogTemp, Warning, TEXT("UWebsocketManager::Connect: Already connected. Disconnect first.")); - return; - } - - if (ServerUrl.IsEmpty()) - { - UE_LOG(LogTemp, Error, TEXT("UWebsocketManager::Connect called with empty URL")); - OnConnectionError.Broadcast(TEXT("Invalid server URL")); - return; - } - - // append InitToken to the connection headers if provided - TMap UpgradeHeaders; - if (!InitToken.IsEmpty()) - { - FString HeaderToken = FString::Printf(TEXT("Bearer %s"), - *InitToken); - UpgradeHeaders.Add("Authorization", HeaderToken); - } - - // using the v1.bsatn.spacetimedb protocol for WebSocket connections - const FString Protocol = "v1.bsatn.spacetimedb"; // @TODO: Implement JSON alternative, v1.json.spacetimedb - - // Create the WebSocket connection - WebSocket = FWebSocketsModule::Get().CreateWebSocket(ServerUrl, Protocol, UpgradeHeaders); - - if (!WebSocket.IsValid()) - { - UE_LOG(LogTemp, Error, TEXT("UWebsocketManager::Connect: Failed to create WebSocket connection to %s."), *ServerUrl); - OnConnectionError.Broadcast(TEXT("Failed to create WebSocket.")); - return; - } - - // Bind event handlers - WebSocket->OnConnected().AddUObject(this, &UWebsocketManager::HandleConnected); - WebSocket->OnConnectionError().AddUObject(this, &UWebsocketManager::HandleConnectionError); - WebSocket->OnMessage().AddUObject(this, &UWebsocketManager::HandleMessageReceived); - WebSocket->OnBinaryMessage().AddUObject(this, &UWebsocketManager::HandleBinaryMessageReceived); - WebSocket->OnClosed().AddUObject(this, &UWebsocketManager::HandleClosed); - - UE_LOG(LogTemp, Log, TEXT("UWebsocketManager::Connect: Connecting to %s..."), *ServerUrl); - // Start the connection process - WebSocket->Connect(); -} - -void UWebsocketManager::Disconnect() -{ - if (!WebSocket.IsValid()) - { - return; - } - - if (IsConnected()) - { - UE_LOG(LogTemp, Log, TEXT("UWebsocketManager::Disconnect: Closing WebSocket connection.")); - WebSocket->Close(); - } - - // Reset the WebSocket to allow for reconnection attempts - WebSocket.Reset(); -} - -bool UWebsocketManager::SendMessage(const FString& Message) -{ - if (!IsConnected()) - { - UE_LOG(LogTemp, Warning, TEXT("UWebsocketManager::SendMessage: WebSocket is not connected.")); - return false; - } - - if (!WebSocket.IsValid()) - { - UE_LOG(LogTemp, Error, TEXT("UWebsocketManager::SendMessage: WebSocket is not valid.")); - return false; - } - - // send the message as a UTF-8 encoded string - WebSocket->Send(Message); - return true; -} - -bool UWebsocketManager::SendMessage(const TArray& Data) -{ - if (!IsConnected()) - { - UE_LOG(LogTemp, Warning, TEXT("UWebsocketManager::SendMessage: WebSocket is not connected.")); - return false; - } - - if (!WebSocket.IsValid()) - { - UE_LOG(LogTemp, Error, TEXT("UWebsocketManager::SendMessage: WebSocket is not valid.")); - return false; - } - - // send the data as a binary message - WebSocket->Send(Data.GetData(), Data.Num(), true); - return true; -} - -bool UWebsocketManager::IsConnected() const -{ - return WebSocket.IsValid() && WebSocket->IsConnected(); -} - -void UWebsocketManager::SetInitToken(FString Token) -{ - InitToken = Token; -} - -void UWebsocketManager::HandleConnected() -{ - UE_LOG(LogTemp, Log, TEXT("UWebsocketManager: WebSocket Connected.")); - OnConnected.Broadcast(); -} - -void UWebsocketManager::HandleConnectionError(const FString& Error) -{ - UE_LOG(LogTemp, Error, TEXT("UWebsocketManager: WebSocket Connection Error: %s"), *Error); - OnConnectionError.Broadcast(Error); - // Reset on error to allow reconnection attempts - WebSocket.Reset(); -} - -void UWebsocketManager::HandleMessageReceived(const FString& Message) -{ - OnMessageReceived.Broadcast(Message); -} - -void UWebsocketManager::HandleBinaryMessageReceived(const void* Data, SIZE_T Size, bool bIsLastFragment) -{ - if (Size == 0) - { - return; - } - - // Handle binary messages, which may be fragmented - const uint8* Bytes = static_cast(Data); - - // Append this fragment to our buffer - IncompleteMessage.Append(Bytes, Size); - - // If this is the last fragment, we have the complete message - if (bIsLastFragment) - { - // We have the complete message - TArray MessageBytes = IncompleteMessage; - IncompleteMessage.Reset(); - bAwaitingBinaryFragments = false; - - // Forward the complete binary payload to listeners. - OnBinaryMessageReceived.Broadcast(MessageBytes); - } - else - { - // More fragments are coming - bAwaitingBinaryFragments = true; - } -} - -void UWebsocketManager::HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean) -{ - UE_LOG(LogTemp, Log, TEXT("UWebsocketManager: WebSocket Closed. Status: %d, Reason: %s, Clean: %s"), - StatusCode, *Reason, bWasClean ? TEXT("true") : TEXT("false")); - // Notify listeners about the closure - OnClosed.Broadcast(StatusCode, Reason, bWasClean); - // Reset on close to allow reconnection attempts - WebSocket.Reset(); + +#include "Connection/Websocket.h" +#include "WebSocketsModule.h" // Required for FWebSocketsModule +#include "SpacetimeDbSdk/Public/BSATN/UESpacetimeDB.h" +#include "ModuleBindings/Types/ServerMessageType.g.h" +#include "ModuleBindings/Types/CompressableQueryUpdateType.g.h" +#include "Misc/Compression.h" + +#include "Dom/JsonObject.h" +#include "Serialization/JsonWriter.h" +#include "Serialization/JsonSerializer.h" + +static void LogIdentityTokenHex(const FIdentityTokenType& InToken, const TCHAR* TagName) +{ + // Logs the identity token in a structured format for debugging purposes. + TSharedRef Obj = MakeShared(); + Obj->SetStringField(TEXT("__identity__"), InToken.Identity.ToHex()); + Obj->SetStringField(TEXT("token"), InToken.Token); + Obj->SetStringField(TEXT("__connection_id__"), InToken.ConnectionId.ToHex()); + + FString Json; + TSharedRef> Writer = TJsonWriterFactory<>::Create(&Json); + FJsonSerializer::Serialize(Obj, Writer); + UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("[%s] %s"), TagName, *Json); +} + +UWebsocketManager::UWebsocketManager() +{ + // Ensure the WebSockets module is loaded. + FModuleManager::LoadModuleChecked(TEXT("WebSockets")); +} + +void UWebsocketManager::BeginDestroy() +{ + UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager::BeginDestroy: Cleaning up WebSocket.")); + if (!HasAnyFlags(RF_ClassDefaultObject)) + { + Disconnect(); + } + Super::BeginDestroy(); +} + +void UWebsocketManager::Connect(const FString& ServerUrl) +{ + if (IsConnected()) + { + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("UWebsocketManager::Connect: Already connected. Disconnect first.")); + return; + } + + if (ServerUrl.IsEmpty()) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("UWebsocketManager::Connect called with empty URL")); + OnConnectionError.Broadcast(TEXT("Invalid server URL")); + return; + } + + // append InitToken to the connection headers if provided + TMap UpgradeHeaders; + if (!InitToken.IsEmpty()) + { + FString HeaderToken = FString::Printf(TEXT("Bearer %s"), + *InitToken); + UpgradeHeaders.Add("Authorization", HeaderToken); + } + + // using the v1.bsatn.spacetimedb protocol for WebSocket connections + const FString Protocol = "v1.bsatn.spacetimedb"; // @TODO: Implement JSON alternative, v1.json.spacetimedb + + // Create the WebSocket connection + WebSocket = FWebSocketsModule::Get().CreateWebSocket(ServerUrl, Protocol, UpgradeHeaders); + + if (!WebSocket.IsValid()) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("UWebsocketManager::Connect: Failed to create WebSocket connection to %s."), *ServerUrl); + OnConnectionError.Broadcast(TEXT("Failed to create WebSocket.")); + return; + } + + // Bind event handlers + WebSocket->OnConnected().AddUObject(this, &UWebsocketManager::HandleConnected); + WebSocket->OnConnectionError().AddUObject(this, &UWebsocketManager::HandleConnectionError); + WebSocket->OnMessage().AddUObject(this, &UWebsocketManager::HandleMessageReceived); + WebSocket->OnBinaryMessage().AddUObject(this, &UWebsocketManager::HandleBinaryMessageReceived); + WebSocket->OnClosed().AddUObject(this, &UWebsocketManager::HandleClosed); + + UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager::Connect: Connecting to %s..."), *ServerUrl); + // Start the connection process + WebSocket->Connect(); +} + +void UWebsocketManager::Disconnect() +{ + if (!WebSocket.IsValid()) + { + return; + } + + if (IsConnected()) + { + UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager::Disconnect: Closing WebSocket connection.")); + WebSocket->Close(); + } + + // Reset the WebSocket to allow for reconnection attempts + WebSocket.Reset(); +} + +bool UWebsocketManager::SendMessage(const FString& Message) +{ + if (!IsConnected()) + { + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("UWebsocketManager::SendMessage: WebSocket is not connected.")); + return false; + } + + if (!WebSocket.IsValid()) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("UWebsocketManager::SendMessage: WebSocket is not valid.")); + return false; + } + + // send the message as a UTF-8 encoded string + WebSocket->Send(Message); + return true; +} + +bool UWebsocketManager::SendMessage(const TArray& Data) +{ + if (!IsConnected()) + { + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("UWebsocketManager::SendMessage: WebSocket is not connected.")); + return false; + } + + if (!WebSocket.IsValid()) + { + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("UWebsocketManager::SendMessage: WebSocket is not valid.")); + return false; + } + + // send the data as a binary message + WebSocket->Send(Data.GetData(), Data.Num(), true); + return true; +} + +bool UWebsocketManager::IsConnected() const +{ + return WebSocket.IsValid() && WebSocket->IsConnected(); +} + +void UWebsocketManager::SetInitToken(FString Token) +{ + InitToken = Token; +} + +void UWebsocketManager::HandleConnected() +{ + UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager: WebSocket Connected.")); + OnConnected.Broadcast(); +} + +void UWebsocketManager::HandleConnectionError(const FString& Error) +{ + UE_LOG(LogSpacetimeDb_Connection, Error, TEXT("UWebsocketManager: WebSocket Connection Error: %s"), *Error); + OnConnectionError.Broadcast(Error); + // Reset on error to allow reconnection attempts + WebSocket.Reset(); +} + +void UWebsocketManager::HandleMessageReceived(const FString& Message) +{ + OnMessageReceived.Broadcast(Message); +} + +void UWebsocketManager::HandleBinaryMessageReceived(const void* Data, SIZE_T Size, bool bIsLastFragment) +{ + if (Size == 0) + { + return; + } + + // Handle binary messages, which may be fragmented + const uint8* Bytes = static_cast(Data); + + // Append this fragment to our buffer + IncompleteMessage.Append(Bytes, Size); + + // If this is the last fragment, we have the complete message + if (bIsLastFragment) + { + // We have the complete message + TArray MessageBytes = IncompleteMessage; + IncompleteMessage.Reset(); + bAwaitingBinaryFragments = false; + + // Forward the complete binary payload to listeners. + OnBinaryMessageReceived.Broadcast(MessageBytes); + } + else + { + // More fragments are coming + bAwaitingBinaryFragments = true; + } +} + +void UWebsocketManager::HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean) +{ + UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("UWebsocketManager: WebSocket Closed. Status: %d, Reason: %s, Clean: %s"), + StatusCode, *Reason, bWasClean ? TEXT("true") : TEXT("false")); + // Notify listeners about the closure + OnClosed.Broadcast(StatusCode, Reason, bWasClean); + // Reset on close to allow reconnection attempts + WebSocket.Reset(); } \ No newline at end of file diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h index fd7ab5fff02..62caa10efa4 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/DbConnectionBase.h @@ -1,403 +1,404 @@ -#pragma once - -#include "CoreMinimal.h" -#include "UObject/NoExportTypes.h" -#include "Types/Builtins.h" -#include "Websocket.h" -#include "Subscription.h" -#include "ModuleBindings/Types/ServerMessageType.g.h" -#include "DBCache/TableAppliedDiff.h" -#include "HAL/CriticalSection.h" -#include "Containers/Queue.h" -#include "HAL/ThreadSafeBool.h" -#include "BSATN/UEBSATNHelpers.h" -#include "Connection/SetReducerFlags.h" -#include "Connection/Callback.h" - -#include "DbConnectionBase.generated.h" - -// Forward declarations -class UDbConnectionBuilder; -class UProcedureCallbacks; - -/** Macro for safae way to bind delegate without needing to write Function name as an FName. */ -#define BIND_DELEGATE_SAFE(DelegateVar, Object, ClassType, FunctionName) \ - DelegateVar.BindUFunction(Object, GET_FUNCTION_NAME_CHECKED(ClassType, FunctionName)) - -/** Macro for safe way to unbind delegate without needing to write Function name as an FName. */ -#define UNBIND_DELEGATE_SAFE(DelegateVar, Object, ClassType, FunctionName) \ - DelegateVar.Remove(Object, GET_FUNCTION_NAME_CHECKED(ClassType, FunctionName)) - -/** Delegate called when the connection attempt fails. */ -DECLARE_DYNAMIC_DELEGATE_OneParam( - FOnConnectErrorDelegate, - const FString&, ErrorMessage); - -/** Called when a connection is established. */ -DECLARE_DYNAMIC_DELEGATE_ThreeParams( - FOnConnectBaseDelegate, - UDbConnectionBase*, Connection, - FSpacetimeDBIdentity, Identity, - const FString&, Token); - -/** Called when a connection closes. */ -DECLARE_DYNAMIC_DELEGATE_TwoParams( - FOnDisconnectBaseDelegate, - UDbConnectionBase*, Connection, - const FString&, Error); - - -/** Key used to index preprocessed table data without relying on row addresses */ -struct FPreprocessedTableKey -{ - uint32 TableId; - FString TableName; - - FPreprocessedTableKey() : TableId(0) {} - FPreprocessedTableKey(uint32 InId, const FString& InName) - : TableId(InId), TableName(InName) { - } - - friend bool operator==(const FPreprocessedTableKey& A, const FPreprocessedTableKey& B) - { - return A.TableId == B.TableId && A.TableName == B.TableName; - } -}; - -FORCEINLINE uint32 GetTypeHash(const FPreprocessedTableKey& Key) -{ - return HashCombine(GetTypeHash(Key.TableId), GetTypeHash(Key.TableName)); -} - -UCLASS() -class SPACETIMEDBSDK_API UDbConnectionBase : public UObject, public FTickableGameObject -{ - GENERATED_BODY() - -public: - - /** The default constructor is private to prevent instantiation without using the builder. */ - explicit UDbConnectionBase(const FObjectInitializer& ObjectInitializer = FObjectInitializer::Get()); - - /** Disconnect from the server. */ - UFUNCTION(BlueprintCallable, Category="SpacetimeDB") - void Disconnect(); - - /** Check if the underlying WebSocket is connected. */ - UFUNCTION(BlueprintPure, Category="SpacetimeDB") - bool IsActive() const; - - UFUNCTION(BlueprintCallable, Category="SpacetimeDB") - void FrameTick(); - - UFUNCTION(BlueprintCallable, Category="SpacetimeDB") - void SetAutoTicking(bool bAutoTick) { bIsAutoTicking = bAutoTick; } - - /** Send a raw JSON message to the server. */ - bool SendRawMessage(const FString& Message); - /** Send a raw binary message to the server. */ - bool SendRawMessage(const TArray& Message); - - /** Get the current subscription builder. This is used to create subscriptions. */ - UFUNCTION() - USubscriptionBuilderBase* SubscriptionBuilderBase(); - - /** Get the current identity of the SpacetimeDB instance. This is used to identify the connection. */ - UFUNCTION(BlueprintPure, Category = "SpacetimeDB") - bool TryGetIdentity(FSpacetimeDBIdentity& OutIdentity) const; - - /** Get the current connection id. This is used to identify the connection. */ - UFUNCTION(BlueprintPure, Category = "SpacetimeDB") - FSpacetimeDBConnectionId GetConnectionId() const; - - // Typed reducer call helper: hides BSATN bytes from callers. - template - void CallReducerTyped(const FString& Reducer, const ArgsStruct& Args, USetReducerFlagsBase* Flags) - { - TArray Bytes = UE::SpacetimeDB::Serialize(Args); - InternalCallReducer(Reducer, MoveTemp(Bytes), Flags); - } - - template - void CallProcedureTyped(const FString& ProcedureName, const ArgsStruct& Args, const FOnProcedureCompleteDelegate& Callback) - { - TArray Bytes = UE::SpacetimeDB::Serialize(Args); - InternalCallProcedure(ProcedureName, MoveTemp(Bytes), Callback); - } - - template - void RegisterTable(const FString& TableName) - { - FScopeLock Lock(&TableDeserializersMutex); - TableDeserializers.Add(TableName, MakeShared>()); - } - - /** Internal interface for applying table updates generically */ - class ITableUpdateHandler - { - public: - virtual ~ITableUpdateHandler() {} - - /** Update the in-memory cache for the table and store the diff */ - virtual void UpdateCache(UDbConnectionBase* Conn, const FTableUpdateType& Update, void* Context) = 0; - - /** Broadcast the previously stored diff */ - virtual void BroadcastDiff(UDbConnectionBase* Conn, void* Context) = 0; - }; - - template - class TTableUpdateHandler : public ITableUpdateHandler - { - public: - explicit TTableUpdateHandler(TableClass* InTable) : Table(InTable) {} - - //** Update the in-memory cache for the table and store the diff */ - virtual void UpdateCache(UDbConnectionBase* Conn, const FTableUpdateType& Update, void* Context) override - { - // Attempt to take preprocessed data if available - TSharedPtr> Pre; - if (Conn->TakePreprocessedTableData(Update, Pre)) - { - // If preprocessed data is available, use it to update the table - LastDiff = Table->Update(Pre->Inserts, Pre->Deletes); - } - else - { - // If no preprocessed data, process the update directly. Backup - UE_LOG(LogTemp, Warning, TEXT("No preprocessed data for table update. Processing directly.")); - TArray> Inserts, Deletes; - UE::SpacetimeDB::ProcessTableUpdateWithBsatn(Update, Inserts, Deletes); - LastDiff = Table->Update(Inserts, Deletes); - } - } - //** Broadcast the last stored diff to the table's delegates */ - virtual void BroadcastDiff(UDbConnectionBase* Conn, void* Context) override - { - EventContext& Ctx = *reinterpret_cast(Context); - Conn->BroadcastDiff(Table, LastDiff, Ctx); - } - - private: - TableClass* Table; - FTableAppliedDiff LastDiff; - }; - //** Register a table with the connection. This will allow the connection to handle updates for the table. - template - void RegisterTable(const FString& TableName, TableClass* Table) - { - RegisterTable(TableName); - FScopeLock Lock(&RegisteredTablesMutex); - RegisteredTables.Add(TableName, MakeShared>(Table)); - } - //** Take preprocessed table row data. */ - template - bool TakePreprocessedTableData(const FTableUpdateType& Update, TSharedPtr>& OutData) - { - FScopeLock Lock(&PreprocessedDataMutex); - FPreprocessedTableKey Key(Update.TableId, Update.TableName); - if (TArray>* Found = PreprocessedTableData.Find(Key)) - { - if (Found->Num() > 0) - { - OutData = StaticCastSharedPtr>((*Found)[0]); - Found->RemoveAt(0); - if (Found->Num() == 0) - { - PreprocessedTableData.Remove(Key); - } - return OutData.IsValid(); - } - } - return false; - } - - -protected: - - friend class UDbConnectionBuilderBase; - friend class UDbConnectionBuilder; - friend class USubscriptionHandleBase; - friend class USubscriptionBuilder; - friend class URemoteReducers; - - /** Allow derived classes to override the delegates used when connecting */ - void SetOnConnectDelegate(const FOnConnectBaseDelegate& Delegate) { OnConnectBaseDelegate = Delegate; } - void SetOnDisconnectDelegate(const FOnDisconnectBaseDelegate& Delegate) { OnDisconnectBaseDelegate = Delegate; } - - UFUNCTION() - void HandleWSError(const FString& Error); - UFUNCTION() - void HandleWSClosed(int32 StatusCode, const FString& Reason, bool bWasClean); - UFUNCTION() - void HandleWSBinaryMessage(const TArray& Message); - - virtual void Tick(float DeltaTime) override; - - virtual TStatId GetStatId() const override; - - virtual bool IsTickable() const override; - - virtual bool IsTickableInEditor() const override; - - /** Internal handler that processes a single server message. */ - void ProcessServerMessage(const FServerMessageType& Message); - void PreProcessDatabaseUpdate(const FDatabaseUpdateType& Update); - /** Decompress and parse a raw message. */ - FServerMessageType PreProcessMessage(const TArray& Message); - bool DecompressPayload(ECompressableQueryUpdateTag Variant, const TArray& In, TArray& Out); - bool DecompressGzip(const TArray& InData, TArray& OutData); - bool DecompressBrotli(const TArray& InData, TArray& OutData); - - /** Pending messages awaiting processing on the game thread. */ - TArray PendingMessages; - - /** Mutex protecting access to PendingMessages. */ - FCriticalSection PendingMessagesMutex; - - /** Map of preprocessed messages keyed by their sequential id. */ - TMap PreprocessedMessages; - - /** Protects PreprocessedMessages and PendingMessages ordering state. */ - FCriticalSection PreprocessMutex; - - /** Counter for assigning ids to incoming messages. */ - FThreadSafeCounter NextPreprocessId; - - /** Id of the next message expected to be released. */ - int32 NextReleaseId = 0; - - // Map of table name to row deserializer - TMap> TableDeserializers; - FCriticalSection TableDeserializersMutex; - - // Map from table update pointer to preprocessed data - TMap>> PreprocessedTableData; - FCriticalSection PreprocessedDataMutex; - - // Map of table name to generic table update handler - TMap> RegisteredTables; - FCriticalSection RegisteredTablesMutex; - - - /** Start a subscription. This will add the subscription to the active list and send a subscribe message to the server. */ - void StartSubscription(USubscriptionHandleBase* Handle); - /** Unsubscribe from a subscription. This will remove the subscription from the active list and send an unsubscribe message to the server. */ - void UnsubscribeInternal(USubscriptionHandleBase* Handle); - - /** Call a reducer on the connected SpacetimeDB instance. */ - void InternalCallReducer(const FString& Reducer, TArray Args, USetReducerFlagsBase* Flags); - - /** Call a reducer on the connected SpacetimeDB instance. */ - void InternalCallProcedure(const FString& ProcedureName, TArray Args, const FOnProcedureCompleteDelegate& Callback); - - /** - * Update function to apply database changes. - * Must be implemented by child classes. - * @param Update - Struct containing update data. - */ - virtual void DbUpdate(const FDatabaseUpdateType& Update, const FSpacetimeDBEvent& Event) {}; - - /** Event handler for reducer events. This can should overridden by child classes to handle specific reducer events. */ - virtual void ReducerEvent(const FReducerEvent& Event) {}; - - /** Event handler for reducer events. This can should overridden by child classes to handle specific reducer events. */ - virtual void ReducerEventFailed(const FReducerEvent& Event, const FString ErrorMessage) {}; - - /** Event handler for procedure events. This can should overridden by child classes to handle specific procedure events. */ - virtual void ProcedureEventFailed(const FProcedureEvent& Event, const FString ErrorMessage) {}; - - /** Event handler for error events. This can should overridden by child classes to handle specific error events. */ - virtual void TriggerError(const FString& ErrorMessage) {}; - - /** Event handler for subscription events. This can should overridden by child classes to handle specific subscription events. */ - virtual void TriggerSubscription() {}; - - /** Apply updates for all registered tables using the provided context pointer */ - void ApplyRegisteredTableUpdates(const FDatabaseUpdateType& Update, void* Context); - - /** Called when a subscription is updated. */ - UPROPERTY() - TMap> ActiveSubscriptions; - - UPROPERTY() - TObjectPtr ProcedureCallbacks; - /** Get the next request id for a message. This is used to track requests and responses. */ - int32 NextRequestId; - /** Get the next subscription id for a subscription. This is used to track subscriptions and their responses. */ - int32 NextSubscriptionId; - /** Get the next request id for a message. This is used to track requests and responses. */ - int32 GetNextRequestId(); - /** Get the next subscription id for a subscription. This is used to track subscriptions and their responses. */ - int32 GetNextSubscriptionId(); - - /** The WebSocket manager used to connect to the server. */ - UPROPERTY() - UWebsocketManager* WebSocket = nullptr; - - /** The URI of the SpacetimeDB server to connect to. */ - UPROPERTY() - FString Uri; - /** The module name to connect to. This is used to identify the SpacetimeDB instance. */ - UPROPERTY() - FString ModuleName; - /** The token used to authenticate the connection. */ - UPROPERTY() - FString Token; - - /** The identity of the SpacetimeDB instance. This is used to identify the connection. */ - UPROPERTY() - FSpacetimeDBIdentity Identity; - UPROPERTY() - /** Whether the identity has been set. This is used to prevent multiple identity sets. */ - bool bIsIdentitySet = false; - /** The connection id of the SpacetimeDB instance. This is used to identify the connection. */ - UPROPERTY() - FSpacetimeDBConnectionId ConnectionId; - - UPROPERTY() - bool bIsAutoTicking = false; - - UPROPERTY() - FOnConnectErrorDelegate OnConnectErrorDelegate; - UPROPERTY() - FOnDisconnectBaseDelegate OnDisconnectBaseDelegate; - UPROPERTY() - FOnConnectBaseDelegate OnConnectBaseDelegate; - - /** Called when the connection is established. */ - template - void BroadcastDiff(TableClass* Table, const FTableAppliedDiff& Diff, const EventContext& Context) - { - if (!Table) return; - - // Broadcast the diff to the table's delegates - if (Table->OnInsert.IsBound()) - { - for (const TPair, RowType>& Pair : Diff.Inserts) - { - Table->OnInsert.Broadcast(Context, Pair.Value); - } - } - - // If the table has a delete delegate, broadcast deletes - if (Table->OnDelete.IsBound()) - { - for (const TPair, RowType>& Pair : Diff.Deletes) - { - Table->OnDelete.Broadcast(Context, Pair.Value); - } - } - - // If the table has an update delegate, broadcast updates - if (Table->OnUpdate.IsBound()) - { - int32 Count = FMath::Min(Diff.UpdateDeletes.Num(), Diff.UpdateInserts.Num()); - for (int32 Index = 0; Index < Count; ++Index) - { - const RowType& OldRow = Diff.UpdateDeletes[Index]; - const RowType& NewRow = Diff.UpdateInserts[Index]; - Table->OnUpdate.Broadcast(Context, OldRow, NewRow); - } - } - } -}; +#pragma once + +#include "CoreMinimal.h" +#include "UObject/NoExportTypes.h" +#include "Types/Builtins.h" +#include "Websocket.h" +#include "Subscription.h" +#include "ModuleBindings/Types/ServerMessageType.g.h" +#include "DBCache/TableAppliedDiff.h" +#include "HAL/CriticalSection.h" +#include "Containers/Queue.h" +#include "HAL/ThreadSafeBool.h" +#include "BSATN/UEBSATNHelpers.h" +#include "Connection/SetReducerFlags.h" +#include "Connection/Callback.h" +#include "LogCategory.h" + +#include "DbConnectionBase.generated.h" + +// Forward declarations +class UDbConnectionBuilder; +class UProcedureCallbacks; + +/** Macro for safae way to bind delegate without needing to write Function name as an FName. */ +#define BIND_DELEGATE_SAFE(DelegateVar, Object, ClassType, FunctionName) \ + DelegateVar.BindUFunction(Object, GET_FUNCTION_NAME_CHECKED(ClassType, FunctionName)) + +/** Macro for safe way to unbind delegate without needing to write Function name as an FName. */ +#define UNBIND_DELEGATE_SAFE(DelegateVar, Object, ClassType, FunctionName) \ + DelegateVar.Remove(Object, GET_FUNCTION_NAME_CHECKED(ClassType, FunctionName)) + +/** Delegate called when the connection attempt fails. */ +DECLARE_DYNAMIC_DELEGATE_OneParam( + FOnConnectErrorDelegate, + const FString&, ErrorMessage); + +/** Called when a connection is established. */ +DECLARE_DYNAMIC_DELEGATE_ThreeParams( + FOnConnectBaseDelegate, + UDbConnectionBase*, Connection, + FSpacetimeDBIdentity, Identity, + const FString&, Token); + +/** Called when a connection closes. */ +DECLARE_DYNAMIC_DELEGATE_TwoParams( + FOnDisconnectBaseDelegate, + UDbConnectionBase*, Connection, + const FString&, Error); + + +/** Key used to index preprocessed table data without relying on row addresses */ +struct FPreprocessedTableKey +{ + uint32 TableId; + FString TableName; + + FPreprocessedTableKey() : TableId(0) {} + FPreprocessedTableKey(uint32 InId, const FString& InName) + : TableId(InId), TableName(InName) { + } + + friend bool operator==(const FPreprocessedTableKey& A, const FPreprocessedTableKey& B) + { + return A.TableId == B.TableId && A.TableName == B.TableName; + } +}; + +FORCEINLINE uint32 GetTypeHash(const FPreprocessedTableKey& Key) +{ + return HashCombine(GetTypeHash(Key.TableId), GetTypeHash(Key.TableName)); +} + +UCLASS() +class SPACETIMEDBSDK_API UDbConnectionBase : public UObject, public FTickableGameObject +{ + GENERATED_BODY() + +public: + + /** The default constructor is private to prevent instantiation without using the builder. */ + explicit UDbConnectionBase(const FObjectInitializer& ObjectInitializer = FObjectInitializer::Get()); + + /** Disconnect from the server. */ + UFUNCTION(BlueprintCallable, Category="SpacetimeDB") + void Disconnect(); + + /** Check if the underlying WebSocket is connected. */ + UFUNCTION(BlueprintPure, Category="SpacetimeDB") + bool IsActive() const; + + UFUNCTION(BlueprintCallable, Category="SpacetimeDB") + void FrameTick(); + + UFUNCTION(BlueprintCallable, Category="SpacetimeDB") + void SetAutoTicking(bool bAutoTick) { bIsAutoTicking = bAutoTick; } + + /** Send a raw JSON message to the server. */ + bool SendRawMessage(const FString& Message); + /** Send a raw binary message to the server. */ + bool SendRawMessage(const TArray& Message); + + /** Get the current subscription builder. This is used to create subscriptions. */ + UFUNCTION() + USubscriptionBuilderBase* SubscriptionBuilderBase(); + + /** Get the current identity of the SpacetimeDB instance. This is used to identify the connection. */ + UFUNCTION(BlueprintPure, Category = "SpacetimeDB") + bool TryGetIdentity(FSpacetimeDBIdentity& OutIdentity) const; + + /** Get the current connection id. This is used to identify the connection. */ + UFUNCTION(BlueprintPure, Category = "SpacetimeDB") + FSpacetimeDBConnectionId GetConnectionId() const; + + // Typed reducer call helper: hides BSATN bytes from callers. + template + void CallReducerTyped(const FString& Reducer, const ArgsStruct& Args, USetReducerFlagsBase* Flags) + { + TArray Bytes = UE::SpacetimeDB::Serialize(Args); + InternalCallReducer(Reducer, MoveTemp(Bytes), Flags); + } + + template + void CallProcedureTyped(const FString& ProcedureName, const ArgsStruct& Args, const FOnProcedureCompleteDelegate& Callback) + { + TArray Bytes = UE::SpacetimeDB::Serialize(Args); + InternalCallProcedure(ProcedureName, MoveTemp(Bytes), Callback); + } + + template + void RegisterTable(const FString& TableName) + { + FScopeLock Lock(&TableDeserializersMutex); + TableDeserializers.Add(TableName, MakeShared>()); + } + + /** Internal interface for applying table updates generically */ + class ITableUpdateHandler + { + public: + virtual ~ITableUpdateHandler() {} + + /** Update the in-memory cache for the table and store the diff */ + virtual void UpdateCache(UDbConnectionBase* Conn, const FTableUpdateType& Update, void* Context) = 0; + + /** Broadcast the previously stored diff */ + virtual void BroadcastDiff(UDbConnectionBase* Conn, void* Context) = 0; + }; + + template + class TTableUpdateHandler : public ITableUpdateHandler + { + public: + explicit TTableUpdateHandler(TableClass* InTable) : Table(InTable) {} + + //** Update the in-memory cache for the table and store the diff */ + virtual void UpdateCache(UDbConnectionBase* Conn, const FTableUpdateType& Update, void* Context) override + { + // Attempt to take preprocessed data if available + TSharedPtr> Pre; + if (Conn->TakePreprocessedTableData(Update, Pre)) + { + // If preprocessed data is available, use it to update the table + LastDiff = Table->Update(Pre->Inserts, Pre->Deletes); + } + else + { + // If no preprocessed data, process the update directly. Backup + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("No preprocessed data for table update. Processing directly.")); + TArray> Inserts, Deletes; + UE::SpacetimeDB::ProcessTableUpdateWithBsatn(Update, Inserts, Deletes); + LastDiff = Table->Update(Inserts, Deletes); + } + } + //** Broadcast the last stored diff to the table's delegates */ + virtual void BroadcastDiff(UDbConnectionBase* Conn, void* Context) override + { + EventContext& Ctx = *reinterpret_cast(Context); + Conn->BroadcastDiff(Table, LastDiff, Ctx); + } + + private: + TableClass* Table; + FTableAppliedDiff LastDiff; + }; + //** Register a table with the connection. This will allow the connection to handle updates for the table. + template + void RegisterTable(const FString& TableName, TableClass* Table) + { + RegisterTable(TableName); + FScopeLock Lock(&RegisteredTablesMutex); + RegisteredTables.Add(TableName, MakeShared>(Table)); + } + //** Take preprocessed table row data. */ + template + bool TakePreprocessedTableData(const FTableUpdateType& Update, TSharedPtr>& OutData) + { + FScopeLock Lock(&PreprocessedDataMutex); + FPreprocessedTableKey Key(Update.TableId, Update.TableName); + if (TArray>* Found = PreprocessedTableData.Find(Key)) + { + if (Found->Num() > 0) + { + OutData = StaticCastSharedPtr>((*Found)[0]); + Found->RemoveAt(0); + if (Found->Num() == 0) + { + PreprocessedTableData.Remove(Key); + } + return OutData.IsValid(); + } + } + return false; + } + + +protected: + + friend class UDbConnectionBuilderBase; + friend class UDbConnectionBuilder; + friend class USubscriptionHandleBase; + friend class USubscriptionBuilder; + friend class URemoteReducers; + + /** Allow derived classes to override the delegates used when connecting */ + void SetOnConnectDelegate(const FOnConnectBaseDelegate& Delegate) { OnConnectBaseDelegate = Delegate; } + void SetOnDisconnectDelegate(const FOnDisconnectBaseDelegate& Delegate) { OnDisconnectBaseDelegate = Delegate; } + + UFUNCTION() + void HandleWSError(const FString& Error); + UFUNCTION() + void HandleWSClosed(int32 StatusCode, const FString& Reason, bool bWasClean); + UFUNCTION() + void HandleWSBinaryMessage(const TArray& Message); + + virtual void Tick(float DeltaTime) override; + + virtual TStatId GetStatId() const override; + + virtual bool IsTickable() const override; + + virtual bool IsTickableInEditor() const override; + + /** Internal handler that processes a single server message. */ + void ProcessServerMessage(const FServerMessageType& Message); + void PreProcessDatabaseUpdate(const FDatabaseUpdateType& Update); + /** Decompress and parse a raw message. */ + FServerMessageType PreProcessMessage(const TArray& Message); + bool DecompressPayload(ECompressableQueryUpdateTag Variant, const TArray& In, TArray& Out); + bool DecompressGzip(const TArray& InData, TArray& OutData); + bool DecompressBrotli(const TArray& InData, TArray& OutData); + + /** Pending messages awaiting processing on the game thread. */ + TArray PendingMessages; + + /** Mutex protecting access to PendingMessages. */ + FCriticalSection PendingMessagesMutex; + + /** Map of preprocessed messages keyed by their sequential id. */ + TMap PreprocessedMessages; + + /** Protects PreprocessedMessages and PendingMessages ordering state. */ + FCriticalSection PreprocessMutex; + + /** Counter for assigning ids to incoming messages. */ + FThreadSafeCounter NextPreprocessId; + + /** Id of the next message expected to be released. */ + int32 NextReleaseId = 0; + + // Map of table name to row deserializer + TMap> TableDeserializers; + FCriticalSection TableDeserializersMutex; + + // Map from table update pointer to preprocessed data + TMap>> PreprocessedTableData; + FCriticalSection PreprocessedDataMutex; + + // Map of table name to generic table update handler + TMap> RegisteredTables; + FCriticalSection RegisteredTablesMutex; + + + /** Start a subscription. This will add the subscription to the active list and send a subscribe message to the server. */ + void StartSubscription(USubscriptionHandleBase* Handle); + /** Unsubscribe from a subscription. This will remove the subscription from the active list and send an unsubscribe message to the server. */ + void UnsubscribeInternal(USubscriptionHandleBase* Handle); + + /** Call a reducer on the connected SpacetimeDB instance. */ + void InternalCallReducer(const FString& Reducer, TArray Args, USetReducerFlagsBase* Flags); + + /** Call a reducer on the connected SpacetimeDB instance. */ + void InternalCallProcedure(const FString& ProcedureName, TArray Args, const FOnProcedureCompleteDelegate& Callback); + + /** + * Update function to apply database changes. + * Must be implemented by child classes. + * @param Update - Struct containing update data. + */ + virtual void DbUpdate(const FDatabaseUpdateType& Update, const FSpacetimeDBEvent& Event) {}; + + /** Event handler for reducer events. This can should overridden by child classes to handle specific reducer events. */ + virtual void ReducerEvent(const FReducerEvent& Event) {}; + + /** Event handler for reducer events. This can should overridden by child classes to handle specific reducer events. */ + virtual void ReducerEventFailed(const FReducerEvent& Event, const FString ErrorMessage) {}; + + /** Event handler for procedure events. This can should overridden by child classes to handle specific procedure events. */ + virtual void ProcedureEventFailed(const FProcedureEvent& Event, const FString ErrorMessage) {}; + + /** Event handler for error events. This can should overridden by child classes to handle specific error events. */ + virtual void TriggerError(const FString& ErrorMessage) {}; + + /** Event handler for subscription events. This can should overridden by child classes to handle specific subscription events. */ + virtual void TriggerSubscription() {}; + + /** Apply updates for all registered tables using the provided context pointer */ + void ApplyRegisteredTableUpdates(const FDatabaseUpdateType& Update, void* Context); + + /** Called when a subscription is updated. */ + UPROPERTY() + TMap> ActiveSubscriptions; + + UPROPERTY() + TObjectPtr ProcedureCallbacks; + /** Get the next request id for a message. This is used to track requests and responses. */ + int32 NextRequestId; + /** Get the next subscription id for a subscription. This is used to track subscriptions and their responses. */ + int32 NextSubscriptionId; + /** Get the next request id for a message. This is used to track requests and responses. */ + int32 GetNextRequestId(); + /** Get the next subscription id for a subscription. This is used to track subscriptions and their responses. */ + int32 GetNextSubscriptionId(); + + /** The WebSocket manager used to connect to the server. */ + UPROPERTY() + UWebsocketManager* WebSocket = nullptr; + + /** The URI of the SpacetimeDB server to connect to. */ + UPROPERTY() + FString Uri; + /** The module name to connect to. This is used to identify the SpacetimeDB instance. */ + UPROPERTY() + FString ModuleName; + /** The token used to authenticate the connection. */ + UPROPERTY() + FString Token; + + /** The identity of the SpacetimeDB instance. This is used to identify the connection. */ + UPROPERTY() + FSpacetimeDBIdentity Identity; + UPROPERTY() + /** Whether the identity has been set. This is used to prevent multiple identity sets. */ + bool bIsIdentitySet = false; + /** The connection id of the SpacetimeDB instance. This is used to identify the connection. */ + UPROPERTY() + FSpacetimeDBConnectionId ConnectionId; + + UPROPERTY() + bool bIsAutoTicking = false; + + UPROPERTY() + FOnConnectErrorDelegate OnConnectErrorDelegate; + UPROPERTY() + FOnDisconnectBaseDelegate OnDisconnectBaseDelegate; + UPROPERTY() + FOnConnectBaseDelegate OnConnectBaseDelegate; + + /** Called when the connection is established. */ + template + void BroadcastDiff(TableClass* Table, const FTableAppliedDiff& Diff, const EventContext& Context) + { + if (!Table) return; + + // Broadcast the diff to the table's delegates + if (Table->OnInsert.IsBound()) + { + for (const TPair, RowType>& Pair : Diff.Inserts) + { + Table->OnInsert.Broadcast(Context, Pair.Value); + } + } + + // If the table has a delete delegate, broadcast deletes + if (Table->OnDelete.IsBound()) + { + for (const TPair, RowType>& Pair : Diff.Deletes) + { + Table->OnDelete.Broadcast(Context, Pair.Value); + } + } + + // If the table has an update delegate, broadcast updates + if (Table->OnUpdate.IsBound()) + { + int32 Count = FMath::Min(Diff.UpdateDeletes.Num(), Diff.UpdateInserts.Num()); + for (int32 Index = 0; Index < Count; ++Index) + { + const RowType& OldRow = Diff.UpdateDeletes[Index]; + const RowType& NewRow = Diff.UpdateInserts[Index]; + Table->OnUpdate.Broadcast(Context, OldRow, NewRow); + } + } + } +}; diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/LogCategory.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/LogCategory.h new file mode 100644 index 00000000000..7881f4657ab --- /dev/null +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/LogCategory.h @@ -0,0 +1,4 @@ + +#pragma once + +SPACETIMEDBSDK_API DECLARE_LOG_CATEGORY_EXTERN(LogSpacetimeDb_Connection, Log, Log); \ No newline at end of file diff --git a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h index ec525c47d50..8ac2b396416 100644 --- a/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h +++ b/sdks/unreal/src/SpacetimeDbSdk/Source/SpacetimeDbSdk/Public/Connection/Websocket.h @@ -1,174 +1,175 @@ -#pragma once - -#include "CoreMinimal.h" -#include "IWebSocket.h" -#include "ModuleBindings/Types/ServerMessageType.g.h" -#include "ModuleBindings/Types/CompressableQueryUpdateType.g.h" -#include "JsonObjectConverter.h" // for JSON debugging helpers -#include "Async/Async.h" -#include "HAL/CriticalSection.h" -#include "Misc/ScopeLock.h" - - -#include "Websocket.generated.h" - -/** Delegate broadcast when a connection is successfully established */ -DECLARE_DYNAMIC_MULTICAST_DELEGATE(FOnWebSocketConnected); -/** Delegate broadcast on connection error */ -DECLARE_DYNAMIC_MULTICAST_DELEGATE_OneParam(FOnWebSocketConnectionError, const FString&, ErrorMessage); -/** Delegate broadcast when a text message is received */ -DECLARE_DYNAMIC_MULTICAST_DELEGATE_OneParam(FOnWebSocketMessageReceived, const FString&, Message); -/** Delegate broadcast when the socket closes */ -DECLARE_DYNAMIC_MULTICAST_DELEGATE_ThreeParams(FOnWebSocketClosed, int32, StatusCode, const FString&, Reason, bool, bWasClean); -/** Delegate broadcast when binary data is received */ -DECLARE_DYNAMIC_MULTICAST_DELEGATE_OneParam(FOnWebSocketBinaryMessageReceived, const TArray&, Data); - - -/** - * Manages the low-level WebSocket connection to the SpacetimeDB server. - * Handles connecting, disconnecting, sending messages, and receiving messages. - */ -UCLASS(BlueprintType) -class SPACETIMEDBSDK_API UWebsocketManager : public UObject -{ - GENERATED_BODY() - -public: - UWebsocketManager(); - - virtual void BeginDestroy() override; - - /** - * Connects to the WebSocket server at the given URL. - * @param ServerUrl The URL of the WebSocket server. - */ - void Connect(const FString& ServerUrl); - - /** - * Disconnects from the WebSocket server. - */ - void Disconnect(); - - /** - * Sends a message to the WebSocket server. - * @param Message The message to send. - * @return True if the message was sent successfully, false otherwise. - */ - bool SendMessage(const FString& Message); - - /** - * Sends binary data to the WebSocket server. - * @param Data The bytes to send. - * @return True if the message was sent successfully, false otherwise. - */ - bool SendMessage(const TArray& Data); - - /** - * Checks if the WebSocket connection is currently active. - * @return True if connected, false otherwise. - */ - bool IsConnected() const; - - /** - * Sets the initial auth token used when connecting. - * @param Token JWT or session token expected by the server. - */ - void SetInitToken(FString Token); - - /** Delegates for WebSocket events */ - UPROPERTY() - FOnWebSocketConnected OnConnected; - - /** Broadcast when a connection error occurs */ - UPROPERTY() - FOnWebSocketConnectionError OnConnectionError; - - /** Broadcast for text messages */ - UPROPERTY() - FOnWebSocketMessageReceived OnMessageReceived; - - /** Broadcast for binary payloads */ - UPROPERTY() - FOnWebSocketBinaryMessageReceived OnBinaryMessageReceived; - - /** Broadcast when the socket is closed */ - UPROPERTY() - FOnWebSocketClosed OnClosed; - -private: - /**Underlying WebSocket implementation */ - TSharedPtr WebSocket; - - /** Handler for successful connection */ - void HandleConnected(); - /** Handler for connection errors */ - void HandleConnectionError(const FString& Error); - /** Handler for incoming text messages */ - void HandleMessageReceived(const FString& Message); - /** Handler for incoming binary messages */ - void HandleBinaryMessageReceived(const void* Data, SIZE_T Size, bool bIsLastFragment); - /** Handler for socket close */ - void HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean); - - /** Decompresses a payload based on compression variant */ - bool DecompressPayload(ECompressableQueryUpdateTag Variant, const TArray& In, TArray& Out); - /** GZip decompression helper */ - bool DecompressGzip(const TArray& InData, TArray& OutData); - /** Brotli decompression helper */ - bool DecompressBrotli(const TArray& InData, TArray& OutData); - - FString InitToken; - - /** Buffer used to accumulate binary fragments until a complete message - * is received. */ - TArray IncompleteMessage; - - /** Tracks if we are waiting for additional binary fragments. */ - bool bAwaitingBinaryFragments = false; - -}; - -// Helper function to log a struct as JSON, expanding any transient objects -template -static void LogAsJson(const StructType& InStruct, const TCHAR* TagName) -{ - FString Json; - if (!FJsonObjectConverter::UStructToJsonObjectString(InStruct, Json)) - { - UE_LOG(LogTemp, Warning, TEXT("[%s] Failed to serialize to JSON"), TagName); - return; - } - - // Print original JSON - UE_LOG(LogTemp, Log, TEXT("[%s] %s"), TagName, *Json); - - // Extract object paths like /Script/SpacetimeDbSdk.CompressableQueryUpdateType'/Engine/Transient.CompressableQueryUpdateType_0' - const FRegexPattern Pattern(TEXT(R"((\/Script\/SpacetimeDbSdk\.\w+)'\/Engine\/Transient\.(\w+))")); - FRegexMatcher Matcher(Pattern, Json); - - while (Matcher.FindNext()) - { - FString ClassName = Matcher.GetCaptureGroup(1); // e.g., /Script/SpacetimeDbSdk.CompressableQueryUpdateType - FString ObjectName = Matcher.GetCaptureGroup(2); // e.g., CompressableQueryUpdateType_0 - - // Find the object in memory - UObject* FoundObj = StaticFindObject(UObject::StaticClass(), GetTransientPackage(), *ObjectName); - if (FoundObj == nullptr) - { - UE_LOG(LogTemp, Warning, TEXT("[%s] Could not find object: %s"), TagName, *ObjectName); - continue; - } - - // Log its expanded contents - FString SubJson; - if (FJsonObjectConverter::UStructToJsonObjectString( - static_cast(FoundObj->GetClass()), FoundObj, SubJson)) - { - UE_LOG(LogTemp, Log, TEXT("[%s] %s: %s"), TagName, *ObjectName, *SubJson); - } - else - { - UE_LOG(LogTemp, Warning, TEXT("[%s] Failed to serialize object: %s"), TagName, *ObjectName); - } - } +#pragma once + +#include "CoreMinimal.h" +#include "IWebSocket.h" +#include "ModuleBindings/Types/ServerMessageType.g.h" +#include "ModuleBindings/Types/CompressableQueryUpdateType.g.h" +#include "JsonObjectConverter.h" // for JSON debugging helpers +#include "Async/Async.h" +#include "HAL/CriticalSection.h" +#include "Misc/ScopeLock.h" +#include "LogCategory.h" + + +#include "Websocket.generated.h" + +/** Delegate broadcast when a connection is successfully established */ +DECLARE_DYNAMIC_MULTICAST_DELEGATE(FOnWebSocketConnected); +/** Delegate broadcast on connection error */ +DECLARE_DYNAMIC_MULTICAST_DELEGATE_OneParam(FOnWebSocketConnectionError, const FString&, ErrorMessage); +/** Delegate broadcast when a text message is received */ +DECLARE_DYNAMIC_MULTICAST_DELEGATE_OneParam(FOnWebSocketMessageReceived, const FString&, Message); +/** Delegate broadcast when the socket closes */ +DECLARE_DYNAMIC_MULTICAST_DELEGATE_ThreeParams(FOnWebSocketClosed, int32, StatusCode, const FString&, Reason, bool, bWasClean); +/** Delegate broadcast when binary data is received */ +DECLARE_DYNAMIC_MULTICAST_DELEGATE_OneParam(FOnWebSocketBinaryMessageReceived, const TArray&, Data); + + +/** + * Manages the low-level WebSocket connection to the SpacetimeDB server. + * Handles connecting, disconnecting, sending messages, and receiving messages. + */ +UCLASS(BlueprintType) +class SPACETIMEDBSDK_API UWebsocketManager : public UObject +{ + GENERATED_BODY() + +public: + UWebsocketManager(); + + virtual void BeginDestroy() override; + + /** + * Connects to the WebSocket server at the given URL. + * @param ServerUrl The URL of the WebSocket server. + */ + void Connect(const FString& ServerUrl); + + /** + * Disconnects from the WebSocket server. + */ + void Disconnect(); + + /** + * Sends a message to the WebSocket server. + * @param Message The message to send. + * @return True if the message was sent successfully, false otherwise. + */ + bool SendMessage(const FString& Message); + + /** + * Sends binary data to the WebSocket server. + * @param Data The bytes to send. + * @return True if the message was sent successfully, false otherwise. + */ + bool SendMessage(const TArray& Data); + + /** + * Checks if the WebSocket connection is currently active. + * @return True if connected, false otherwise. + */ + bool IsConnected() const; + + /** + * Sets the initial auth token used when connecting. + * @param Token JWT or session token expected by the server. + */ + void SetInitToken(FString Token); + + /** Delegates for WebSocket events */ + UPROPERTY() + FOnWebSocketConnected OnConnected; + + /** Broadcast when a connection error occurs */ + UPROPERTY() + FOnWebSocketConnectionError OnConnectionError; + + /** Broadcast for text messages */ + UPROPERTY() + FOnWebSocketMessageReceived OnMessageReceived; + + /** Broadcast for binary payloads */ + UPROPERTY() + FOnWebSocketBinaryMessageReceived OnBinaryMessageReceived; + + /** Broadcast when the socket is closed */ + UPROPERTY() + FOnWebSocketClosed OnClosed; + +private: + /**Underlying WebSocket implementation */ + TSharedPtr WebSocket; + + /** Handler for successful connection */ + void HandleConnected(); + /** Handler for connection errors */ + void HandleConnectionError(const FString& Error); + /** Handler for incoming text messages */ + void HandleMessageReceived(const FString& Message); + /** Handler for incoming binary messages */ + void HandleBinaryMessageReceived(const void* Data, SIZE_T Size, bool bIsLastFragment); + /** Handler for socket close */ + void HandleClosed(int32 StatusCode, const FString& Reason, bool bWasClean); + + /** Decompresses a payload based on compression variant */ + bool DecompressPayload(ECompressableQueryUpdateTag Variant, const TArray& In, TArray& Out); + /** GZip decompression helper */ + bool DecompressGzip(const TArray& InData, TArray& OutData); + /** Brotli decompression helper */ + bool DecompressBrotli(const TArray& InData, TArray& OutData); + + FString InitToken; + + /** Buffer used to accumulate binary fragments until a complete message + * is received. */ + TArray IncompleteMessage; + + /** Tracks if we are waiting for additional binary fragments. */ + bool bAwaitingBinaryFragments = false; + +}; + +// Helper function to log a struct as JSON, expanding any transient objects +template +static void LogAsJson(const StructType& InStruct, const TCHAR* TagName) +{ + FString Json; + if (!FJsonObjectConverter::UStructToJsonObjectString(InStruct, Json)) + { + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("[%s] Failed to serialize to JSON"), TagName); + return; + } + + // Print original JSON + UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("[%s] %s"), TagName, *Json); + + // Extract object paths like /Script/SpacetimeDbSdk.CompressableQueryUpdateType'/Engine/Transient.CompressableQueryUpdateType_0' + const FRegexPattern Pattern(TEXT(R"((\/Script\/SpacetimeDbSdk\.\w+)'\/Engine\/Transient\.(\w+))")); + FRegexMatcher Matcher(Pattern, Json); + + while (Matcher.FindNext()) + { + FString ClassName = Matcher.GetCaptureGroup(1); // e.g., /Script/SpacetimeDbSdk.CompressableQueryUpdateType + FString ObjectName = Matcher.GetCaptureGroup(2); // e.g., CompressableQueryUpdateType_0 + + // Find the object in memory + UObject* FoundObj = StaticFindObject(UObject::StaticClass(), GetTransientPackage(), *ObjectName); + if (FoundObj == nullptr) + { + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("[%s] Could not find object: %s"), TagName, *ObjectName); + continue; + } + + // Log its expanded contents + FString SubJson; + if (FJsonObjectConverter::UStructToJsonObjectString( + static_cast(FoundObj->GetClass()), FoundObj, SubJson)) + { + UE_LOG(LogSpacetimeDb_Connection, Log, TEXT("[%s] %s: %s"), TagName, *ObjectName, *SubJson); + } + else + { + UE_LOG(LogSpacetimeDb_Connection, Warning, TEXT("[%s] Failed to serialize object: %s"), TagName, *ObjectName); + } + } } \ No newline at end of file