diff --git a/cmd/api/src/analysis/azure/post.go b/cmd/api/src/analysis/azure/post.go index e101b6943d..c574cc0dec 100644 --- a/cmd/api/src/analysis/azure/post.go +++ b/cmd/api/src/analysis/azure/post.go @@ -29,7 +29,7 @@ import ( func Post(ctx context.Context, db graph.Database) (*analysis.AtomicPostProcessingStats, error) { aggregateStats := analysis.NewAtomicPostProcessingStats() - if stats, err := analysis.DeleteTransitEdges(ctx, db, graph.Kinds{ad.Entity, azure.Entity}, azureAnalysis.AzurePostProcessedRelationships()...); err != nil { + if stats, err := analysis.DeleteTransitEdges(ctx, db, graph.Kinds{ad.Entity, azure.Entity}, azureAnalysis.PostProcessedRelationships()...); err != nil { return &aggregateStats, err } else if userRoleStats, err := azureAnalysis.UserRoleAssignments(ctx, db); err != nil { return &aggregateStats, err diff --git a/cmd/api/src/api/saml/saml.go b/cmd/api/src/api/saml/saml.go index 436c09627a..e854de3d09 100644 --- a/cmd/api/src/api/saml/saml.go +++ b/cmd/api/src/api/saml/saml.go @@ -350,11 +350,11 @@ func (s ProviderResource) serveAssertionConsumerService(response http.ResponseWr s.writeAPIErrorResponse(request, response, http.StatusBadRequest, "session assertion does not meet the requirements for user lookup") } else { s.authenticator.CreateSSOSession(request, response, principalName, model.SSOProvider{ - Type: model.SessionAuthProviderSAML, - Name: s.serviceProvider.Config.Name, - Slug: s.serviceProvider.Config.Name, + Type: model.SessionAuthProviderSAML, + Name: s.serviceProvider.Config.Name, + Slug: s.serviceProvider.Config.Name, SAMLProvider: &s.serviceProvider.Config, - Serial: model.Serial{ ID: s.serviceProvider.Config.SSOProviderID.Int32 }, + Serial: model.Serial{ID: s.serviceProvider.Config.SSOProviderID.Int32}, }) } } diff --git a/cmd/api/src/api/saml/saml_internal_test.go b/cmd/api/src/api/saml/saml_internal_test.go index 94503131dd..a0a36cee9a 100644 --- a/cmd/api/src/api/saml/saml_internal_test.go +++ b/cmd/api/src/api/saml/saml_internal_test.go @@ -42,12 +42,12 @@ import ( func SSOProviderFromResource(resource ProviderResource) model.SSOProvider { return model.SSOProvider{ - Type: model.SessionAuthProviderSAML, - Name: resource.serviceProvider.Config.Name, - Slug: resource.serviceProvider.Config.Name, - SAMLProvider: &resource.serviceProvider.Config, - Serial: model.Serial{ ID: resource.serviceProvider.Config.SSOProviderID.Int32 }, - } + Type: model.SessionAuthProviderSAML, + Name: resource.serviceProvider.Config.Name, + Slug: resource.serviceProvider.Config.Name, + SAMLProvider: &resource.serviceProvider.Config, + Serial: model.Serial{ID: resource.serviceProvider.Config.SSOProviderID.Int32}, + } } func TestAuth_CreateSSOSession(t *testing.T) { @@ -58,7 +58,7 @@ func TestAuth_CreateSSOSession(t *testing.T) { SAMLProvider: &model.SAMLProvider{ Serial: model.Serial{ID: 1}, }, - SSOProviderID: null.Int32From(1), + SSOProviderID: null.Int32From(1), SAMLProviderID: null.Int32From(1), } @@ -71,7 +71,7 @@ func TestAuth_CreateSSOSession(t *testing.T) { config.Configuration{RootURL: serde.MustParseURL("https://example.com")}, bhsaml.ServiceProvider{ Config: model.SAMLProvider{ - Serial: model.Serial{ID: 1}, + Serial: model.Serial{ID: 1}, SSOProviderID: null.Int32From(1), }, }, diff --git a/cmd/api/src/api/v2/auth/oidc.go b/cmd/api/src/api/v2/auth/oidc.go index 825b4d91d1..67541408b5 100644 --- a/cmd/api/src/api/v2/auth/oidc.go +++ b/cmd/api/src/api/v2/auth/oidc.go @@ -74,7 +74,7 @@ func (s ManagementResource) OIDCLoginHandler(response http.ResponseWriter, reque ClientID: ssoProvider.OIDCProvider.ClientID, Endpoint: provider.Endpoint(), RedirectURL: getRedirectURL(request, ssoProvider), - Scopes: []string{"openid", "profile", "email", "email_verified", "name", "given_name", "family_name"}, + Scopes: []string{"openid", "profile", "email", "email_verified", "name", "given_name", "family_name"}, } // use PKCE to protect against CSRF attacks @@ -134,7 +134,7 @@ func (s ManagementResource) OIDCCallbackHandler(response http.ResponseWriter, re // Extract custom claims var claims struct { Name string `json:"name"` - FamilyName string `json:"family_name"` + FamilyName string `json:"family_name"` DisplayName string `json:"given_name"` Email string `json:"email"` Verified bool `json:"email_verified"` diff --git a/packages/go/analysis/ad/post.go b/packages/go/analysis/ad/post.go index 4c08959de2..06e15820f0 100644 --- a/packages/go/analysis/ad/post.go +++ b/packages/go/analysis/ad/post.go @@ -56,9 +56,11 @@ func PostProcessedRelationships() []graph.Kind { ad.ADCSESC10a, ad.ADCSESC10b, ad.ADCSESC9a, + ad.ADCSESC9b, ad.ADCSESC13, ad.EnrollOnBehalfOf, ad.SyncedToEntraUser, + ad.ExtendedByPolicy, } } diff --git a/packages/go/analysis/azure/post.go b/packages/go/analysis/azure/post.go index a6e3d2b20d..6239b60551 100644 --- a/packages/go/analysis/azure/post.go +++ b/packages/go/analysis/azure/post.go @@ -117,7 +117,7 @@ func PasswordAdministratorPasswordResetTargetRoles() []string { } } -func AzurePostProcessedRelationships() []graph.Kind { +func PostProcessedRelationships() []graph.Kind { return []graph.Kind{ azure.AddSecret, azure.ExecuteCommand, @@ -205,40 +205,47 @@ func AppRoleAssignments(ctx context.Context, db graph.Database) (*analysis.Atomi if tenants, err := FetchTenants(ctx, db); err != nil { return &analysis.AtomicPostProcessingStats{}, err } else { - operation := analysis.NewPostRelationshipOperation(ctx, db, "Azure App Role Assignments Post Processing") + var ( + operation = analysis.NewPostRelationshipOperation(ctx, db, "Azure App Role Assignments Post Processing") + edgeConstraintMap = analysis.NewEdgeConstraintMap() + ) + for _, tenant := range tenants { if err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { if tenantContainsServicePrincipalRelationships, err := fetchTenantContainsRelationships(tx, tenant, azure.ServicePrincipal); err != nil { return err - } else if err := createAZMGApplicationReadWriteAllEdges(ctx, db, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { + } else if err := createAZMGApplicationReadWriteAllEdges(ctx, db, edgeConstraintMap, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { return err - } else if err := createAZMGAppRoleAssignmentReadWriteAllEdges(ctx, db, operation, tenantContainsServicePrincipalRelationships); err != nil { + } else if err := createAZMGAppRoleAssignmentReadWriteAllEdges(ctx, db, edgeConstraintMap, operation, tenantContainsServicePrincipalRelationships); err != nil { return err - } else if err := createAZMGDirectoryReadWriteAllEdges(ctx, db, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { + } else if err := createAZMGDirectoryReadWriteAllEdges(ctx, db, edgeConstraintMap, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { return err - } else if err := createAZMGGroupReadWriteAllEdges(ctx, db, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { + } else if err := createAZMGGroupReadWriteAllEdges(ctx, db, edgeConstraintMap, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { return err - } else if err := createAZMGGroupMemberReadWriteAllEdges(ctx, db, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { + } else if err := createAZMGGroupMemberReadWriteAllEdges(ctx, db, edgeConstraintMap, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { return err - } else if err := createAZMGRoleManagementReadWriteDirectoryEdgesPart1(ctx, db, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { + } else if err := createAZMGRoleManagementReadWriteDirectoryEdgesPart1(ctx, db, edgeConstraintMap, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { return err - } else if err := createAZMGRoleManagementReadWriteDirectoryEdgesPart2(ctx, db, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { + } else if err := createAZMGRoleManagementReadWriteDirectoryEdgesPart2(ctx, db, edgeConstraintMap, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { return err - } else if err := createAZMGRoleManagementReadWriteDirectoryEdgesPart3(ctx, db, operation, tenantContainsServicePrincipalRelationships); err != nil { + } else if err := createAZMGRoleManagementReadWriteDirectoryEdgesPart3(ctx, db, edgeConstraintMap, operation, tenantContainsServicePrincipalRelationships); err != nil { return err - } else if err := createAZMGRoleManagementReadWriteDirectoryEdgesPart4(ctx, db, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { + } else if err := createAZMGRoleManagementReadWriteDirectoryEdgesPart4(ctx, db, edgeConstraintMap, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { return err - } else if err := createAZMGRoleManagementReadWriteDirectoryEdgesPart5(ctx, db, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { + } else if err := createAZMGRoleManagementReadWriteDirectoryEdgesPart5(ctx, db, edgeConstraintMap, operation, tenant, tenantContainsServicePrincipalRelationships); err != nil { return err - } else if err := createAZMGServicePrincipalEndpointReadWriteAllEdges(ctx, db, operation, tenantContainsServicePrincipalRelationships); err != nil { + } else if err := createAZMGServicePrincipalEndpointReadWriteAllEdges(ctx, db, edgeConstraintMap, operation, tenantContainsServicePrincipalRelationships); err != nil { return err - } else if err := addSecret(ctx, db, operation, tenant); err != nil { + } else if err := addSecret(edgeConstraintMap, operation, tenant); err != nil { return err } return nil }); err != nil { - operation.Done() + if err := operation.Done(); err != nil { + log.Errorf("Error caught during azure AppRoleAssignments teardown: %v", err) + } + return &operation.Stats, err } } @@ -246,17 +253,15 @@ func AppRoleAssignments(ctx context.Context, db graph.Database) (*analysis.Atomi } } -func createAZMGApplicationReadWriteAllEdges(ctx context.Context, db graph.Database, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { +func createAZMGApplicationReadWriteAllEdges(ctx context.Context, db graph.Database, edgeConstraintMap analysis.EdgeConstraintMap, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { if err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { if tenantContainsAppRelationships, err := fetchTenantContainsRelationships(tx, tenant, azure.App); err != nil { return err } else if sourceNodes, err := aggregateSourceReadWriteServicePrincipals(tx, tenantContainsServicePrincipalRelationships, azure.ApplicationReadWriteAll); err != nil { return err } else { - targetRelationships := append(tenantContainsServicePrincipalRelationships, tenantContainsAppRelationships...) - - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { - for _, targetRelationship := range targetRelationships { + return operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + for _, targetRelationship := range append(tenantContainsServicePrincipalRelationships, tenantContainsAppRelationships...) { for _, sourceNode := range sourceNodes { AZMGAddSecretRelationship := analysis.CreatePostRelationshipJob{ FromID: sourceNode.ID, @@ -264,7 +269,7 @@ func createAZMGApplicationReadWriteAllEdges(ctx context.Context, db graph.Databa Kind: azure.AZMGAddSecret, } - if !channels.Submit(ctx, outC, AZMGAddSecretRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGAddSecretRelationship) && !channels.Submit(ctx, outC, AZMGAddSecretRelationship) { return nil } @@ -274,15 +279,15 @@ func createAZMGApplicationReadWriteAllEdges(ctx context.Context, db graph.Databa Kind: azure.AZMGAddOwner, } - if !channels.Submit(ctx, outC, AZMGAddOwnerRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGAddOwnerRelationship) && !channels.Submit(ctx, outC, AZMGAddOwnerRelationship) { return nil } } } + return nil }) } - return nil }); err != nil { return err } else { @@ -290,12 +295,12 @@ func createAZMGApplicationReadWriteAllEdges(ctx context.Context, db graph.Databa } } -func createAZMGAppRoleAssignmentReadWriteAllEdges(ctx context.Context, db graph.Database, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenantContainsServicePrincipalRelationships []*graph.Relationship) error { +func createAZMGAppRoleAssignmentReadWriteAllEdges(ctx context.Context, db graph.Database, edgeConstraintMap analysis.EdgeConstraintMap, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenantContainsServicePrincipalRelationships []*graph.Relationship) error { if err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { if sourceNodes, err := aggregateSourceReadWriteServicePrincipals(tx, tenantContainsServicePrincipalRelationships, azure.AppRoleAssignmentReadWriteAll); err != nil { return err } else { - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + return operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { for _, tenantContainsServicePrincipalRelationship := range tenantContainsServicePrincipalRelationships { for _, sourceNode := range sourceNodes { AZMGGrantAppRolesRelationship := analysis.CreatePostRelationshipJob{ @@ -304,7 +309,7 @@ func createAZMGAppRoleAssignmentReadWriteAllEdges(ctx context.Context, db graph. Kind: azure.AZMGGrantAppRoles, } - if !channels.Submit(ctx, outC, AZMGGrantAppRolesRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGGrantAppRolesRelationship) && !channels.Submit(ctx, outC, AZMGGrantAppRolesRelationship) { return nil } } @@ -313,7 +318,6 @@ func createAZMGAppRoleAssignmentReadWriteAllEdges(ctx context.Context, db graph. return nil }) } - return nil }); err != nil { return err } else { @@ -321,14 +325,14 @@ func createAZMGAppRoleAssignmentReadWriteAllEdges(ctx context.Context, db graph. } } -func createAZMGDirectoryReadWriteAllEdges(ctx context.Context, db graph.Database, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { +func createAZMGDirectoryReadWriteAllEdges(ctx context.Context, db graph.Database, edgeConstraintMap analysis.EdgeConstraintMap, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { if err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { if sourceNodes, err := aggregateSourceReadWriteServicePrincipals(tx, tenantContainsServicePrincipalRelationships, azure.DirectoryReadWriteAll); err != nil { return err } else if tenantContainsGroupRelationships, err := fetchTenantContainsReadWriteAllGroupRelationships(tx, tenant); err != nil { return err } else { - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + return operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { for _, tenantContainsGroupRelationship := range tenantContainsGroupRelationships { for _, sourceNode := range sourceNodes { AZMGAddMemberRelationship := analysis.CreatePostRelationshipJob{ @@ -337,7 +341,7 @@ func createAZMGDirectoryReadWriteAllEdges(ctx context.Context, db graph.Database Kind: azure.AZMGAddMember, } - if !channels.Submit(ctx, outC, AZMGAddMemberRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGAddMemberRelationship) && !channels.Submit(ctx, outC, AZMGAddMemberRelationship) { return nil } @@ -347,7 +351,7 @@ func createAZMGDirectoryReadWriteAllEdges(ctx context.Context, db graph.Database Kind: azure.AZMGAddOwner, } - if !channels.Submit(ctx, outC, AZMGAddOwnerRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGAddOwnerRelationship) && !channels.Submit(ctx, outC, AZMGAddOwnerRelationship) { return nil } } @@ -355,7 +359,6 @@ func createAZMGDirectoryReadWriteAllEdges(ctx context.Context, db graph.Database return nil }) } - return nil }); err != nil { return err } else { @@ -363,14 +366,14 @@ func createAZMGDirectoryReadWriteAllEdges(ctx context.Context, db graph.Database } } -func createAZMGGroupReadWriteAllEdges(ctx context.Context, db graph.Database, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { +func createAZMGGroupReadWriteAllEdges(ctx context.Context, db graph.Database, edgeConstraintMap analysis.EdgeConstraintMap, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { if err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { if sourceNodes, err := aggregateSourceReadWriteServicePrincipals(tx, tenantContainsServicePrincipalRelationships, azure.GroupReadWriteAll); err != nil { return err } else if tenantContainsGroupRelationships, err := fetchTenantContainsReadWriteAllGroupRelationships(tx, tenant); err != nil { return err } else { - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + return operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { for _, tenantContainsGroupRelationship := range tenantContainsGroupRelationships { for _, sourceNode := range sourceNodes { AZMGAddMemberRelationship := analysis.CreatePostRelationshipJob{ @@ -379,7 +382,7 @@ func createAZMGGroupReadWriteAllEdges(ctx context.Context, db graph.Database, op Kind: azure.AZMGAddMember, } - if !channels.Submit(ctx, outC, AZMGAddMemberRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGAddMemberRelationship) && !channels.Submit(ctx, outC, AZMGAddMemberRelationship) { return nil } @@ -389,7 +392,7 @@ func createAZMGGroupReadWriteAllEdges(ctx context.Context, db graph.Database, op Kind: azure.AZMGAddOwner, } - if !channels.Submit(ctx, outC, AZMGAddOwnerRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGAddOwnerRelationship) && !channels.Submit(ctx, outC, AZMGAddOwnerRelationship) { return nil } } @@ -397,7 +400,6 @@ func createAZMGGroupReadWriteAllEdges(ctx context.Context, db graph.Database, op return nil }) } - return nil }); err != nil { return err } else { @@ -405,14 +407,14 @@ func createAZMGGroupReadWriteAllEdges(ctx context.Context, db graph.Database, op } } -func createAZMGGroupMemberReadWriteAllEdges(ctx context.Context, db graph.Database, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { +func createAZMGGroupMemberReadWriteAllEdges(ctx context.Context, db graph.Database, edgeConstraintMap analysis.EdgeConstraintMap, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { if err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { if sourceNodes, err := aggregateSourceReadWriteServicePrincipals(tx, tenantContainsServicePrincipalRelationships, azure.GroupMemberReadWriteAll); err != nil { return err } else if tenantContainsGroupRelationships, err := fetchTenantContainsReadWriteAllGroupRelationships(tx, tenant); err != nil { return err } else { - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + return operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { for _, tenantContainsGroupRelationship := range tenantContainsGroupRelationships { for _, sourceNode := range sourceNodes { AZMGAddMemberRelationship := analysis.CreatePostRelationshipJob{ @@ -421,7 +423,7 @@ func createAZMGGroupMemberReadWriteAllEdges(ctx context.Context, db graph.Databa Kind: azure.AZMGAddMember, } - if !channels.Submit(ctx, outC, AZMGAddMemberRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGAddMemberRelationship) && !channels.Submit(ctx, outC, AZMGAddMemberRelationship) { return nil } } @@ -429,7 +431,6 @@ func createAZMGGroupMemberReadWriteAllEdges(ctx context.Context, db graph.Databa return nil }) } - return nil }); err != nil { return err } else { @@ -437,14 +438,14 @@ func createAZMGGroupMemberReadWriteAllEdges(ctx context.Context, db graph.Databa } } -func createAZMGRoleManagementReadWriteDirectoryEdgesPart1(ctx context.Context, db graph.Database, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { +func createAZMGRoleManagementReadWriteDirectoryEdgesPart1(ctx context.Context, db graph.Database, edgeConstraintMap analysis.EdgeConstraintMap, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { if err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { if sourceNodes, err := aggregateSourceReadWriteServicePrincipals(tx, tenantContainsServicePrincipalRelationships, azure.RoleManagementReadWriteDirectory); err != nil { return err } else if tenantContainsRoleRelationships, err := fetchTenantContainsRelationships(tx, tenant, azure.Role); err != nil { return err } else { - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + return operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { for _, tenantContainsRoleRelationship := range tenantContainsRoleRelationships { for _, sourceNode := range sourceNodes { AZMGGrantAppRolesRelationship := analysis.CreatePostRelationshipJob{ @@ -453,7 +454,7 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart1(ctx context.Context, d Kind: azure.AZMGGrantAppRoles, } - if !channels.Submit(ctx, outC, AZMGGrantAppRolesRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGGrantAppRolesRelationship) && !channels.Submit(ctx, outC, AZMGGrantAppRolesRelationship) { return nil } } @@ -461,7 +462,6 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart1(ctx context.Context, d return nil }) } - return nil }); err != nil { return err } else { @@ -469,14 +469,14 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart1(ctx context.Context, d } } -func createAZMGRoleManagementReadWriteDirectoryEdgesPart2(ctx context.Context, db graph.Database, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { +func createAZMGRoleManagementReadWriteDirectoryEdgesPart2(ctx context.Context, db graph.Database, edgeConstraintMap analysis.EdgeConstraintMap, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { if err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { if sourceNodes, err := aggregateSourceReadWriteServicePrincipals(tx, tenantContainsServicePrincipalRelationships, azure.RoleManagementReadWriteDirectory); err != nil { return err } else if tenantContainsRoleRelationships, err := fetchTenantContainsRelationships(tx, tenant, azure.Role); err != nil { return err } else { - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + return operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { for _, tenantContainsRoleRelationship := range tenantContainsRoleRelationships { for _, sourceNode := range sourceNodes { AZMGGrantRoleRelationship := analysis.CreatePostRelationshipJob{ @@ -485,7 +485,7 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart2(ctx context.Context, d Kind: azure.AZMGGrantRole, } - if !channels.Submit(ctx, outC, AZMGGrantRoleRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGGrantRoleRelationship) && !channels.Submit(ctx, outC, AZMGGrantRoleRelationship) { return nil } } @@ -493,7 +493,6 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart2(ctx context.Context, d return nil }) } - return nil }); err != nil { return err } else { @@ -501,12 +500,12 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart2(ctx context.Context, d } } -func createAZMGRoleManagementReadWriteDirectoryEdgesPart3(ctx context.Context, db graph.Database, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenantContainsServicePrincipalRelationships []*graph.Relationship) error { +func createAZMGRoleManagementReadWriteDirectoryEdgesPart3(ctx context.Context, db graph.Database, edgeConstraintMap analysis.EdgeConstraintMap, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenantContainsServicePrincipalRelationships []*graph.Relationship) error { if err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { if sourceNodes, err := aggregateSourceReadWriteServicePrincipals(tx, tenantContainsServicePrincipalRelationships, azure.RoleManagementReadWriteDirectory); err != nil { return err } else { - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + return operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { for _, tenantContainsServicePrincipalRelationship := range tenantContainsServicePrincipalRelationships { for _, sourceNode := range sourceNodes { AZMGAddSecretRelationship := analysis.CreatePostRelationshipJob{ @@ -515,7 +514,7 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart3(ctx context.Context, d Kind: azure.AZMGAddSecret, } - if !channels.Submit(ctx, outC, AZMGAddSecretRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGAddSecretRelationship) && !channels.Submit(ctx, outC, AZMGAddSecretRelationship) { return nil } @@ -525,15 +524,15 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart3(ctx context.Context, d Kind: azure.AZMGAddOwner, } - if !channels.Submit(ctx, outC, AZMGAddOwnerRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGAddOwnerRelationship) && !channels.Submit(ctx, outC, AZMGAddOwnerRelationship) { return nil } } } + return nil }) } - return nil }); err != nil { return err } else { @@ -541,14 +540,14 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart3(ctx context.Context, d } } -func createAZMGRoleManagementReadWriteDirectoryEdgesPart4(ctx context.Context, db graph.Database, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { +func createAZMGRoleManagementReadWriteDirectoryEdgesPart4(ctx context.Context, db graph.Database, edgeConstraintMap analysis.EdgeConstraintMap, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { if err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { if sourceNodes, err := aggregateSourceReadWriteServicePrincipals(tx, tenantContainsServicePrincipalRelationships, azure.RoleManagementReadWriteDirectory); err != nil { return err } else if tenantContainsAppRelationships, err := fetchTenantContainsRelationships(tx, tenant, azure.App); err != nil { return err } else { - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + return operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { for _, tenantContainsAppRelationship := range tenantContainsAppRelationships { for _, sourceNode := range sourceNodes { AZMGAddSecretRelationship := analysis.CreatePostRelationshipJob{ @@ -557,7 +556,7 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart4(ctx context.Context, d Kind: azure.AZMGAddSecret, } - if !channels.Submit(ctx, outC, AZMGAddSecretRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGAddSecretRelationship) && !channels.Submit(ctx, outC, AZMGAddSecretRelationship) { return nil } @@ -567,15 +566,15 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart4(ctx context.Context, d Kind: azure.AZMGAddOwner, } - if !channels.Submit(ctx, outC, AZMGAddOwnerRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGAddOwnerRelationship) && !channels.Submit(ctx, outC, AZMGAddOwnerRelationship) { return nil } } } + return nil }) } - return nil }); err != nil { return err } else { @@ -583,14 +582,14 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart4(ctx context.Context, d } } -func createAZMGRoleManagementReadWriteDirectoryEdgesPart5(ctx context.Context, db graph.Database, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { +func createAZMGRoleManagementReadWriteDirectoryEdgesPart5(ctx context.Context, db graph.Database, edgeConstraintMap analysis.EdgeConstraintMap, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, tenantContainsServicePrincipalRelationships []*graph.Relationship) error { if err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { if sourceNodes, err := aggregateSourceReadWriteServicePrincipals(tx, tenantContainsServicePrincipalRelationships, azure.RoleManagementReadWriteDirectory); err != nil { return err } else if tenantContainsGroupRelationships, err := fetchTenantContainsRelationships(tx, tenant, azure.Group); err != nil { return err } else { - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + return operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { for _, tenantContainsGroupRelationship := range tenantContainsGroupRelationships { for _, sourceNode := range sourceNodes { AZMGAddMemberRelationship := analysis.CreatePostRelationshipJob{ @@ -599,7 +598,7 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart5(ctx context.Context, d Kind: azure.AZMGAddMember, } - if !channels.Submit(ctx, outC, AZMGAddMemberRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGAddMemberRelationship) && !channels.Submit(ctx, outC, AZMGAddMemberRelationship) { return nil } @@ -609,7 +608,7 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart5(ctx context.Context, d Kind: azure.AZMGAddOwner, } - if !channels.Submit(ctx, outC, AZMGAddOwnerRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGAddOwnerRelationship) && !channels.Submit(ctx, outC, AZMGAddOwnerRelationship) { return nil } } @@ -617,7 +616,6 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart5(ctx context.Context, d return nil }) } - return nil }); err != nil { return err } else { @@ -625,12 +623,12 @@ func createAZMGRoleManagementReadWriteDirectoryEdgesPart5(ctx context.Context, d } } -func createAZMGServicePrincipalEndpointReadWriteAllEdges(ctx context.Context, db graph.Database, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenantContainsServicePrincipalRelationships []*graph.Relationship) error { +func createAZMGServicePrincipalEndpointReadWriteAllEdges(ctx context.Context, db graph.Database, edgeConstraintMap analysis.EdgeConstraintMap, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenantContainsServicePrincipalRelationships []*graph.Relationship) error { if err := db.ReadTransaction(ctx, func(tx graph.Transaction) error { if sourceNodes, err := aggregateSourceReadWriteServicePrincipals(tx, tenantContainsServicePrincipalRelationships, azure.ServicePrincipalEndpointReadWriteAll); err != nil { return err } else { - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + return operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { for _, tenantContainsServicePrincipalRelationship := range tenantContainsServicePrincipalRelationships { for _, sourceNode := range sourceNodes { AZMGAddOwnerRelationship := analysis.CreatePostRelationshipJob{ @@ -639,15 +637,15 @@ func createAZMGServicePrincipalEndpointReadWriteAllEdges(ctx context.Context, db Kind: azure.AZMGAddOwner, } - if !channels.Submit(ctx, outC, AZMGAddOwnerRelationship) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(AZMGAddOwnerRelationship) && !channels.Submit(ctx, outC, AZMGAddOwnerRelationship) { return nil } } } + return nil }) } - return nil }); err != nil { return err } else { @@ -655,7 +653,7 @@ func createAZMGServicePrincipalEndpointReadWriteAllEdges(ctx context.Context, db } } -func addSecret(_ context.Context, _ graph.Database, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node) error { +func addSecret(edgeConstraintMap analysis.EdgeConstraintMap, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node) error { return operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { if addSecretRoles, err := TenantRoles(tx, tenant, AddSecretRoleIDs()...); err != nil { return err @@ -671,7 +669,7 @@ func addSecret(_ context.Context, _ graph.Database, operation analysis.StatTrack Kind: azure.AddSecret, } - if !channels.Submit(ctx, outC, nextJob) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(nextJob) && !channels.Submit(ctx, outC, nextJob) { return nil } } @@ -698,7 +696,8 @@ func ExecuteCommand(ctx context.Context, db graph.Database) (*analysis.AtomicPos } else { for _, tenantDevice := range tenantDevices { innerTenantDevice := tenantDevice - operation.Operation.SubmitReader(func(ctx context.Context, _ graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + + if err := operation.Operation.SubmitReader(func(ctx context.Context, _ graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { if isWindowsDevice, err := IsWindowsDevice(innerTenantDevice); err != nil { return err } else if isWindowsDevice { @@ -716,14 +715,19 @@ func ExecuteCommand(ctx context.Context, db graph.Database) (*analysis.AtomicPos } return nil - }) + }); err != nil { + return err + } } } } return nil }); err != nil { - operation.Done() + if err := operation.Done(); err != nil { + log.Errorf("Error caught during azure ExecuteCommand teardown: %v", err) + } + return &operation.Stats, err } @@ -731,7 +735,7 @@ func ExecuteCommand(ctx context.Context, db graph.Database) (*analysis.AtomicPos } } -func resetPassword(_ context.Context, _ graph.Database, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, roleAssignments RoleAssignments) error { +func resetPassword(operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob], tenant *graph.Node, roleAssignments RoleAssignments) error { return operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { if pwResetRoles, err := TenantRoles(tx, tenant, ResetPasswordRoleIDs()...); err != nil { return err @@ -787,12 +791,13 @@ func resetPasswordEndNodeBitmapForRole(role *graph.Node, roleAssignments RoleAss default: return nil, fmt.Errorf("role node %d has unsupported role template id '%s'", role.ID, roleTemplateID) } + return result, nil } } -func globalAdmins(roleAssignments RoleAssignments, tenant *graph.Node, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob]) { - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { +func globalAdmins(roleAssignments RoleAssignments, edgeConstraintMap analysis.EdgeConstraintMap, tenant *graph.Node, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob]) { + if err := operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { roleAssignments.PrincipalsWithRole(azure.CompanyAdministratorRole).Each(func(nextID uint64) bool { nextJob := analysis.CreatePostRelationshipJob{ FromID: graph.ID(nextID), @@ -800,7 +805,7 @@ func globalAdmins(roleAssignments RoleAssignments, tenant *graph.Node, operation Kind: azure.GlobalAdmin, } - if !channels.Submit(ctx, outC, nextJob) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(nextJob) && !channels.Submit(ctx, outC, nextJob) { return false } @@ -808,11 +813,13 @@ func globalAdmins(roleAssignments RoleAssignments, tenant *graph.Node, operation }) return nil - }) + }); err != nil { + log.Errorf("Failed to submit azure global admins post processing job: %v", err) + } } -func privilegedRoleAdmins(roleAssignments RoleAssignments, tenant *graph.Node, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob]) { - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { +func privilegedRoleAdmins(roleAssignments RoleAssignments, edgeConstraintMap analysis.EdgeConstraintMap, tenant *graph.Node, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob]) { + if err := operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { roleAssignments.PrincipalsWithRole(azure.PrivilegedRoleAdministratorRole).Each(func(nextID uint64) bool { nextJob := analysis.CreatePostRelationshipJob{ FromID: graph.ID(nextID), @@ -820,7 +827,7 @@ func privilegedRoleAdmins(roleAssignments RoleAssignments, tenant *graph.Node, o Kind: azure.PrivilegedRoleAdmin, } - if !channels.Submit(ctx, outC, nextJob) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(nextJob) && !channels.Submit(ctx, outC, nextJob) { return false } @@ -828,11 +835,13 @@ func privilegedRoleAdmins(roleAssignments RoleAssignments, tenant *graph.Node, o }) return nil - }) + }); err != nil { + log.Errorf("Failed to submit privileged role admins post processing job: %v", err) + } } -func privilegedAuthAdmins(roleAssignments RoleAssignments, tenant *graph.Node, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob]) { - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { +func privilegedAuthAdmins(roleAssignments RoleAssignments, edgeConstraintMap analysis.EdgeConstraintMap, tenant *graph.Node, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob]) { + if err := operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { roleAssignments.PrincipalsWithRole(azure.PrivilegedAuthenticationAdministratorRole).Each(func(nextID uint64) bool { nextJob := analysis.CreatePostRelationshipJob{ FromID: graph.ID(nextID), @@ -840,7 +849,7 @@ func privilegedAuthAdmins(roleAssignments RoleAssignments, tenant *graph.Node, o Kind: azure.PrivilegedAuthAdmin, } - if !channels.Submit(ctx, outC, nextJob) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(nextJob) && !channels.Submit(ctx, outC, nextJob) { return false } @@ -848,16 +857,19 @@ func privilegedAuthAdmins(roleAssignments RoleAssignments, tenant *graph.Node, o }) return nil - }) + }); err != nil { + log.Errorf("Failed to submit azure privileged auth admins post processing job: %v", err) + } } -func addMembers(roleAssignments RoleAssignments, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob]) { - tenantGroups := roleAssignments.Principals.Get(azure.Group) +func addMembers(roleAssignments RoleAssignments, edgeConstraintMap analysis.EdgeConstraintMap, operation analysis.StatTrackedOperation[analysis.CreatePostRelationshipJob]) { + for tenantGroupID, tenantGroup := range roleAssignments.Principals.Get(azure.Group) { + var ( + innerGroupID = tenantGroupID + innerGroup = tenantGroup + ) - for tenantGroupID, tenantGroup := range tenantGroups { - innerGroupID := tenantGroupID - innerGroup := tenantGroup - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + if err := operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { roleAssignments.UsersWithRole(AddMemberAllGroupsTargetRoles()...).Each(func(nextID uint64) bool { nextJob := analysis.CreatePostRelationshipJob{ FromID: graph.ID(nextID), @@ -865,7 +877,7 @@ func addMembers(roleAssignments RoleAssignments, operation analysis.StatTrackedO Kind: azure.AddMembers, } - if !channels.Submit(ctx, outC, nextJob) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(nextJob) && !channels.Submit(ctx, outC, nextJob) { return false } @@ -873,9 +885,11 @@ func addMembers(roleAssignments RoleAssignments, operation analysis.StatTrackedO }) return nil - }) + }); err != nil { + log.Errorf("Failed to submit azure add members AddMemberAllGroupsTargetRoles post processing job: %v", err) + } - operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { + if err := operation.Operation.SubmitReader(func(ctx context.Context, tx graph.Transaction, outC chan<- analysis.CreatePostRelationshipJob) error { if isRoleAssignable, err := innerGroup.Properties.Get(azure.IsAssignableToRole.String()).Bool(); err != nil { if graph.IsErrPropertyNotFound(err) { log.Warnf("Node %d is missing property %s", innerGroup.ID, azure.IsAssignableToRole) @@ -890,7 +904,7 @@ func addMembers(roleAssignments RoleAssignments, operation analysis.StatTrackedO Kind: azure.AddMembers, } - if !channels.Submit(ctx, outC, nextJob) { + if edgeConstraintMap.TrackCreatePostRelationshipJob(nextJob) && !channels.Submit(ctx, outC, nextJob) { return false } @@ -899,7 +913,9 @@ func addMembers(roleAssignments RoleAssignments, operation analysis.StatTrackedO } return nil - }) + }); err != nil { + log.Errorf("Failed to submit azure add members AddMemberGroupNotRoleAssignableTargetRoles post processing job: %v", err) + } } } @@ -907,20 +923,30 @@ func UserRoleAssignments(ctx context.Context, db graph.Database) (*analysis.Atom if tenantNodes, err := FetchTenants(ctx, db); err != nil { return &analysis.AtomicPostProcessingStats{}, err } else { - operation := analysis.NewPostRelationshipOperation(ctx, db, "Azure User Role Assignments Post Processing") + var ( + operation = analysis.NewPostRelationshipOperation(ctx, db, "Azure User Role Assignments Post Processing") + edgeConstraintMap = analysis.NewEdgeConstraintMap() + ) + for _, tenant := range tenantNodes { if roleAssignments, err := TenantRoleAssignments(ctx, db, tenant); err != nil { - operation.Done() + if err := operation.Done(); err != nil { + log.Errorf("Error caught during azure UserRoleAssignments.TenantRoleAssignments teardown: %v", err) + } + return &analysis.AtomicPostProcessingStats{}, err } else { - if err := resetPassword(ctx, db, operation, tenant, roleAssignments); err != nil { - operation.Done() + if err := resetPassword(operation, tenant, roleAssignments); err != nil { + if err := operation.Done(); err != nil { + log.Errorf("Error caught during azure UserRoleAssignments.resetPassword teardown: %v", err) + } + return &analysis.AtomicPostProcessingStats{}, err } else { - globalAdmins(roleAssignments, tenant, operation) - privilegedRoleAdmins(roleAssignments, tenant, operation) - privilegedAuthAdmins(roleAssignments, tenant, operation) - addMembers(roleAssignments, operation) + globalAdmins(roleAssignments, edgeConstraintMap, tenant, operation) + privilegedRoleAdmins(roleAssignments, edgeConstraintMap, tenant, operation) + privilegedAuthAdmins(roleAssignments, edgeConstraintMap, tenant, operation) + addMembers(roleAssignments, edgeConstraintMap, operation) } } } diff --git a/packages/go/analysis/post.go b/packages/go/analysis/post.go index 16f0a8de79..cba34105ee 100644 --- a/packages/go/analysis/post.go +++ b/packages/go/analysis/post.go @@ -19,6 +19,9 @@ package analysis import ( "context" "sort" + "sync" + + "github.com/specterops/bloodhound/dawgs/cardinality" "github.com/specterops/bloodhound/dawgs/graph" "github.com/specterops/bloodhound/dawgs/ops" @@ -120,6 +123,57 @@ type DeleteRelationshipJob struct { ID graph.ID } +// EdgeConstraintMap is a thread safe tracker for post-processed edges. It guarantees that only one edge of a given +// post-processed kind may exist. This is useful either to create a batch of post-processed edges to insert or to +// guard against double insertion of the same post-processed edge. +type EdgeConstraintMap struct { + lock *sync.Mutex + adjacent map[graph.ID]map[graph.Kind]cardinality.Duplex[uint64] +} + +func NewEdgeConstraintMap() EdgeConstraintMap { + return EdgeConstraintMap{ + lock: &sync.Mutex{}, + adjacent: map[graph.ID]map[graph.Kind]cardinality.Duplex[uint64]{}, + } +} + +// TrackCreatePostRelationshipJob decomposes a CreatePostRelationshipJob type and returns the result of tracking it. +func (s EdgeConstraintMap) TrackCreatePostRelationshipJob(job CreatePostRelationshipJob) bool { + return s.Track(job.FromID, job.ToID, job.Kind) +} + +// Track will attempt to track creation of the given relationship arguments. This function returns false +// if the given relationship has already been tracked; true otherwise. +func (s EdgeConstraintMap) Track(start, end graph.ID, edgeKind graph.Kind) bool { + s.lock.Lock() + defer s.lock.Unlock() + + // Lookup what's adjacent outbound from the start ID + if startAdjacent, exists := s.adjacent[start]; !exists { + // If there's nothing adjacent for the start ID then this is the first outbound edge being created + // for it. + s.adjacent[start] = map[graph.Kind]cardinality.Duplex[uint64]{ + edgeKind: cardinality.NewBitmap64With(end.Uint64()), + } + } else if kindAdjacent, exists := startAdjacent[edgeKind]; !exists { + // If there's no bitmap representing outbound edges over the given edge kind then create a new bitmap + // and track it. + startAdjacent[edgeKind] = cardinality.NewBitmap64With(end.Uint64()) + } else if !kindAdjacent.CheckedAdd(end.Uint64()) { + // This Debugf statement is here to help engineers figure out where double-inserts are coming from in + // the post-processing logic. + log.Debugf("Duplicate post-processed edge: (%d)-[:%s]->(%d)", start, edgeKind, end) + + // If the CheckedAdd function returns false, the outbound nodes already contains an entry for this end + // ID, meaning that the edge has already been created. + return false + } + + // Getting here means we have added a new, unique edge. + return true +} + func DeleteTransitEdges(ctx context.Context, db graph.Database, baseKinds graph.Kinds, targetRelationships ...graph.Kind) (*AtomicPostProcessingStats, error) { defer log.Measure(log.LevelInfo, "Finished deleting transit edges")() diff --git a/packages/go/cypher/models/pgsql/operators.go b/packages/go/cypher/models/pgsql/operators.go index 2ae2fac6b2..2f0010285d 100644 --- a/packages/go/cypher/models/pgsql/operators.go +++ b/packages/go/cypher/models/pgsql/operators.go @@ -98,6 +98,8 @@ const ( OperatorIn Operator = "in" OperatorIs Operator = "is" OperatorIsNot Operator = "is not" + OperatorSimilarTo Operator = "similar to" + OperatorRegexMatch Operator = "=~" OperatorStartsWith Operator = "starts with" OperatorContains Operator = "contains" OperatorEndsWith Operator = "ends with" diff --git a/packages/go/cypher/models/pgsql/pgtypes.go b/packages/go/cypher/models/pgsql/pgtypes.go index 684bfc9a0d..de4a12ba6f 100644 --- a/packages/go/cypher/models/pgsql/pgtypes.go +++ b/packages/go/cypher/models/pgsql/pgtypes.go @@ -294,22 +294,34 @@ func ValueToDataType(value any) (DataType, error) { case time.Duration: return Interval, nil + // * uint8 is here since it can't fit in a signed byte and therefore must coerce into a higher sized type case uint8, int8, int16: return Int2, nil + // * uint8 is here since it can't fit in a signed byte and therefore must coerce into a higher sized type case []uint8, []int8, []int16: return Int2Array, nil - case uint16, int32, graph.ID: + // * uint16 is here since it can't fit in a signed 16-bit value and therefore must coerce into a higher sized type + case uint16, int32: return Int4, nil - case []uint16, []int32, []graph.ID: + // * uint16 is here since it can't fit in a signed 16-bit value and therefore must coerce into a higher sized type + case []uint16, []int32: return Int4Array, nil - case uint32, uint, uint64, int, int64: + // * uint32 is here since it can't fit in a signed 16-bit value and therefore must coerce into a higher sized type + // * uint is here because it is architecture dependent but expecting it to be an unsigned value between 32-bits and + // 64-bits is fine. + // * int is here for the same reasons as uint + case uint32, uint, uint64, int, int64, graph.ID: return Int8, nil - case []uint32, []uint, []uint64, []int, []int64: + // * uint32 is here since it can't fit in a signed 16-bit value and therefore must coerce into a higher sized type + // * uint is here because it is architecture dependent but expecting it to be an unsigned value between 32-bits and + // 64-bits is fine. + // * int is here for the same reasons as uint + case []uint32, []uint, []uint64, []int, []int64, []graph.ID: return Int8Array, nil case float32: diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql index a55a06350b..e49a4e0e84 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_expansion.sql @@ -426,3 +426,175 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat where ex0.satisfied) select edges_to_path(variadic ep0)::pathcomposite as p from s0; + +-- case: match p = (:NodeKind1)<-[:EdgeKind1|EdgeKind2*..]-() return p limit 10 +with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, path) as (select e0.start_id, + e0.end_id, + 1, + false, + e0.start_id = e0.end_id, + array [e0.id] + from edge e0 + join node n0 on + n0.kind_ids operator (pg_catalog.&&) + array [1]::int2[] and + n0.id = e0.end_id + join node n1 on n1.id = e0.start_id + where e0.kind_id = any (array [11, 12]::int2[]) + union + select ex0.root_id, + e0.end_id, + ex0.depth + 1, + false, + e0.id = any (ex0.path), + ex0.path || e0.id + from ex0 + join edge e0 on e0.start_id = ex0.next_id + join node n1 on n1.id = e0.start_id + where ex0.depth < 5 + and not ex0.is_cycle + and e0.kind_id = any (array [11, 12]::int2[])) + select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, + (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) + from edge e0 + where e0.id = any (ex0.path)) as e0, + ex0.path as ep0, + (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from ex0 + join edge e0 on e0.id = any (ex0.path) + join node n0 on n0.id = ex0.root_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.start_id) +select edges_to_path(variadic ep0)::pathcomposite as p +from s0 +limit 10; + + +-- case: match p = (:NodeKind1)<-[:EdgeKind1|EdgeKind2*..]-(:NodeKind2)<-[:EdgeKind1|EdgeKind2*..]-(:NodeKind1) return p limit 10 +with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, path) as (select e0.start_id, + e0.end_id, + 1, + n1.kind_ids operator (pg_catalog.&&) array [2]::int2[], + e0.start_id = e0.end_id, + array [e0.id] + from edge e0 + join node n0 on + n0.kind_ids operator (pg_catalog.&&) + array [1]::int2[] and + n0.id = e0.end_id + join node n1 on n1.id = e0.start_id + where e0.kind_id = any (array [11, 12]::int2[]) + union + select ex0.root_id, + e0.end_id, + ex0.depth + 1, + n1.kind_ids operator (pg_catalog.&&) array [2]::int2[], + e0.id = any (ex0.path), + ex0.path || e0.id + from ex0 + join edge e0 on e0.start_id = ex0.next_id + join node n1 on n1.id = e0.start_id + where ex0.depth < 5 + and not ex0.is_cycle + and e0.kind_id = any (array [11, 12]::int2[])) + select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, + (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) + from edge e0 + where e0.id = any (ex0.path)) as e0, + ex0.path as ep0, + (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from ex0 + join edge e0 on e0.id = any (ex0.path) + join node n0 on n0.id = ex0.root_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.start_id + where ex0.satisfied), + s1 as (with recursive ex1(root_id, next_id, depth, satisfied, is_cycle, path) as (select e1.start_id, + e1.end_id, + 1, + false, + e1.start_id = e1.end_id, + array [e1.id] + from s0 + join edge e1 on + e1.kind_id = any + (array [11, 12]::int2[]) and + (s0.n1).id = e1.end_id + join node n2 on n2.id = e1.start_id + union + select ex1.root_id, + e1.end_id, + ex1.depth + 1, + n2.kind_ids operator (pg_catalog.&&) array [1]::int2[], + e1.id = any (ex1.path), + ex1.path || e1.id + from ex1 + join edge e1 on e1.start_id = ex1.next_id + join node n2 on n2.id = e1.start_id + where ex1.depth < 5 + and not ex1.is_cycle) + select s0.e0 as e0, + s0.ep0 as ep0, + s0.n0 as n0, + s0.n1 as n1, + (select array_agg((e1.id, e1.start_id, e1.end_id, e1.kind_id, e1.properties)::edgecomposite) + from edge e1 + where e1.id = any (ex1.path)) as e1, + ex1.path as ep1, + (n2.id, n2.kind_ids, n2.properties)::nodecomposite as n2 + from s0, + ex1 + join edge e1 on e1.id = any (ex1.path) + join node n1 on n1.id = ex1.root_id + join node n2 on e1.id = ex1.path[array_length(ex1.path, 1)::int4] and n2.id = e1.start_id) +select edges_to_path(variadic s1.ep1 || s1.ep0)::pathcomposite as p +from s1 +limit 10; + +-- case: match p = (n:NodeKind1)-[:EdgeKind1|EdgeKind2*1..2]->(r:NodeKind2) where r.name =~ '(?i)Global Administrator.*|User Administrator.*|Cloud Application Administrator.*|Authentication Policy Administrator.*|Exchange Administrator.*|Helpdesk Administrator.*|Privileged Authentication Administrator.*' return p limit 10 +with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, path) as (select e0.start_id, + e0.end_id, + 1, + n1.kind_ids operator (pg_catalog.&&) + array [2]::int2[] and + n1.properties ->> + 'name' similar to + '(?i)Global Administrator.*|User Administrator.*|Cloud Application Administrator.*|Authentication Policy Administrator.*|Exchange Administrator.*|Helpdesk Administrator.*|Privileged Authentication Administrator.*', + e0.start_id = e0.end_id, + array [e0.id] + from edge e0 + join node n0 on + n0.kind_ids operator (pg_catalog.&&) + array [1]::int2[] and + n0.id = e0.start_id + join node n1 on n1.id = e0.end_id + where e0.kind_id = any (array [11, 12]::int2[]) + union + select ex0.root_id, + e0.end_id, + ex0.depth + 1, + n1.kind_ids operator (pg_catalog.&&) + array [2]::int2[] and + n1.properties ->> + 'name' similar to + '(?i)Global Administrator.*|User Administrator.*|Cloud Application Administrator.*|Authentication Policy Administrator.*|Exchange Administrator.*|Helpdesk Administrator.*|Privileged Authentication Administrator.*', + e0.id = any (ex0.path), + ex0.path || e0.id + from ex0 + join edge e0 on e0.start_id = ex0.next_id + join node n1 on n1.id = e0.end_id + where ex0.depth < 5 + and not ex0.is_cycle + and e0.kind_id = any (array [11, 12]::int2[])) + select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, + (select array_agg((e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite) + from edge e0 + where e0.id = any (ex0.path)) as e0, + ex0.path as ep0, + (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from ex0 + join edge e0 on e0.id = any (ex0.path) + join node n0 on n0.id = ex0.root_id + join node n1 on e0.id = ex0.path[array_length(ex0.path, 1)::int4] and n1.id = e0.end_id + where ex0.satisfied) +select edges_to_path(variadic ep0)::pathcomposite as p +from s0 +limit 10; diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql b/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql index 51d24f730e..21cb8fbc11 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql @@ -278,3 +278,17 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite where e0.kind_id = any (array [11]::int2[])) select (s0.n0).id, (s0.n0).kind_ids, (s0.e0).id, (s0.e0).kind_id from s0; + +-- case: match (s)-[r]->(e) where s:NodeKind1 and toLower(s.name) starts with 'test' and r:EdgeKind1 and id(e) in [1, 2] return r limit 1 +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, + (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, + (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from edge e0 + join node n0 on n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and + lower(n0.properties ->> 'name')::text like 'test' and n0.id = e0.start_id + join node n1 + on n1.id = any (array [1, 2]::int8[]) and n1.id = e0.end_id + where e0.kind_id = any (array [11]::int2[])) +select s0.e0 as r +from s0 +limit 1; diff --git a/packages/go/cypher/models/pgsql/translate/README.md b/packages/go/cypher/models/pgsql/translate/README.md index a54054b800..af9052d88a 100644 --- a/packages/go/cypher/models/pgsql/translate/README.md +++ b/packages/go/cypher/models/pgsql/translate/README.md @@ -183,7 +183,7 @@ from s1; The translator represents pattern expansion as a [recursive CTE](https://www.postgresql.org/docs/current/queries-with.html#QUERIES-WITH-RECURSIVE) to offload as much -of the traversal work to the database. +of the traversal work to the database. Currently, all expansions are hard limited to an expansion depth of 5 steps. Consider the following openCypher query: `match (n)-[*..]->(e) return n, e`. @@ -196,12 +196,12 @@ with s0 as (with recursive ex0(root_id, next_id, depth, satisfied, is_cycle, pat | Column | type | Usage | |-------------|---------|----------------------------------------------------------------------------------------| -| `root_id` | Int4 | Node that the path originated from. Simplifies referencing the root node of each path. | -| `next_id` | Int4 | Next node to expand to. | +| `root_id` | Int8 | Node that the path originated from. Simplifies referencing the root node of each path. | +| `next_id` | Int8 | Next node to expand to. | | `depth` | Int | Depth of the current traversal. | | `satisfied` | Boolean | True if the expansion is satisfied. | | `is_cycle` | Boolean | True if the expansion is a cycle. | -| `path` | Int4[] | Array of edges in order of traversal. | +| `path` | Int8[] | Array of edges in order of traversal. | The translator then formats two queries. First is the primer query that populates the initial pathspace of the expansion: diff --git a/packages/go/cypher/models/pgsql/translate/expansion.go b/packages/go/cypher/models/pgsql/translate/expansion.go index e292d2c741..872b66c2bb 100644 --- a/packages/go/cypher/models/pgsql/translate/expansion.go +++ b/packages/go/cypher/models/pgsql/translate/expansion.go @@ -17,12 +17,123 @@ package translate import ( + "fmt" + "github.com/specterops/bloodhound/cypher/models" "github.com/specterops/bloodhound/cypher/models/pgsql" "github.com/specterops/bloodhound/cypher/models/pgsql/format" "github.com/specterops/bloodhound/dawgs/drivers/pg/model" + "github.com/specterops/bloodhound/dawgs/graph" ) +type expansionRootComponents struct { + RightNodeConstraints *Constraint + LeftNodeConstraints *Constraint + RecursiveVisible *pgsql.IdentifierSet + PrimerWhereClause pgsql.Expression + RecursiveWhereClause pgsql.Expression +} + +func prepareExpansionRootComponents(part *PatternPart, traversalStep *PatternSegment, treeTranslator *ExpressionTreeTranslator) (expansionRootComponents, error) { + expansionComponents := expansionRootComponents{ + RecursiveVisible: traversalStep.Expansion.Value.Frame.Visible.Copy(), + } + + if terminalNode, err := traversalStep.TerminalNode(); err != nil { + return expansionComponents, err + } else if rootNode, err := traversalStep.RootNode(); err != nil { + return expansionComponents, err + } else if rootNodeConstraints, err := consumeConstraintsFrom(pgsql.AsIdentifierSet(rootNode.Identifier), treeTranslator.IdentifierConstraints, part.Constraints); err != nil { + return expansionComponents, err + } else if terminalNodeConstraints, err := consumeConstraintsFrom(pgsql.AsIdentifierSet(terminalNode.Identifier), treeTranslator.IdentifierConstraints, part.Constraints); err != nil { + return expansionComponents, err + } else { + // The exclusion below is done at this step in the process since the recursive descent portion of the query no longer has + // a reference to the root node and any dependent interaction between the root and terminal nodes would require an + // additional join. By not consuming the remaining constraints for the root and terminal nodes, they become visible up + // in the outer select of the recursive CTE. + switch traversalStep.Direction { + case graph.DirectionInbound: + expansionComponents.LeftNodeConstraints = terminalNodeConstraints + expansionComponents.RightNodeConstraints = rootNodeConstraints + + expansionComponents.RecursiveVisible.Remove(rootNode.Identifier) + + case graph.DirectionOutbound: + expansionComponents.LeftNodeConstraints = rootNodeConstraints + expansionComponents.RightNodeConstraints = terminalNodeConstraints + + expansionComponents.RecursiveVisible.Remove(terminalNode.Identifier) + + default: + return expansionComponents, fmt.Errorf("graph direction %s not supported", traversalStep.Direction.String()) + } + + if edgeConstraints, err := consumeConstraintsFrom(expansionComponents.RecursiveVisible, treeTranslator.IdentifierConstraints, part.Constraints); err != nil { + return expansionComponents, err + } else { + // Set the edge constraints in the primer and recursive select where clauses + expansionComponents.PrimerWhereClause = edgeConstraints.Expression + expansionComponents.RecursiveWhereClause = pgsql.OptionalAnd(edgeConstraints.Expression, expansionConstraints(traversalStep.Expansion.Value.Binding.Identifier)) + } + } + + return expansionComponents, nil +} + +type expansionStepComponents struct { + RightNodeConstraints *Constraint + EdgeConstraints *Constraint + RecursiveVisible *pgsql.IdentifierSet + RecursiveWhereClause pgsql.Expression +} + +func prepareExpansionStepComponents(part *PatternPart, traversalStep *PatternSegment, treeTranslator *ExpressionTreeTranslator) (expansionStepComponents, error) { + expansionComponents := expansionStepComponents{ + RecursiveVisible: traversalStep.Expansion.Value.Frame.Visible.Copy(), + } + + // The exclusion in scope below is done at this step in the process since the recursive descent portion of the query no longer has + // a reference to the root and any dependent interaction between the root and terminal nodes would require an additional join. + // By not consuming the remaining constraints for the root and terminal nodes, they become visible up in the outer select of the + // recursive CTE. + + switch traversalStep.Direction { + case graph.DirectionInbound: + if rootNode, err := traversalStep.RootNode(); err != nil { + return expansionComponents, err + } else if rootNodeConstraints, err := consumeConstraintsFrom(pgsql.AsIdentifierSet(rootNode.Identifier), treeTranslator.IdentifierConstraints, part.Constraints); err != nil { + return expansionComponents, err + } else { + expansionComponents.RightNodeConstraints = rootNodeConstraints + expansionComponents.RecursiveVisible.Remove(rootNode.Identifier) + } + + case graph.DirectionOutbound: + if terminalNode, err := traversalStep.TerminalNode(); err != nil { + return expansionComponents, err + } else if terminalNodeConstraints, err := consumeConstraintsFrom(pgsql.AsIdentifierSet(terminalNode.Identifier), treeTranslator.IdentifierConstraints, part.Constraints); err != nil { + return expansionComponents, err + } else { + expansionComponents.RightNodeConstraints = terminalNodeConstraints + expansionComponents.RecursiveVisible.Remove(terminalNode.Identifier) + } + + default: + return expansionComponents, fmt.Errorf("graph direction %s not supported", traversalStep.Direction.String()) + } + + if edgeConstraints, err := consumeConstraintsFrom(expansionComponents.RecursiveVisible, treeTranslator.IdentifierConstraints, part.Constraints); err != nil { + return expansionComponents, err + } else { + // Set the edge constraints in the primer and recursive select where clauses + expansionComponents.EdgeConstraints = edgeConstraints + expansionComponents.RecursiveWhereClause = expansionConstraints(traversalStep.Expansion.Value.Binding.Identifier) + } + + return expansionComponents, nil +} + func expansionConstraints(expansionIdentifier pgsql.Identifier) pgsql.Expression { return pgsql.NewBinaryExpression( pgsql.NewBinaryExpression( @@ -453,40 +564,21 @@ func (s *Translator) buildExpansionPatternRoot(part *PatternPart, traversalStep }, }, - RecursiveStatement: pgsql.Select{ - Where: expansionConstraints(traversalStep.Expansion.Value.Binding.Identifier), - }, + RecursiveStatement: pgsql.Select{}, } ) expansion.ProjectionStatement.Projection = traversalStep.Expansion.Value.Projection - if terminalNode, err := traversalStep.TerminalNode(); err != nil { - return pgsql.Query{}, err - } else if rootNode, err := traversalStep.RootNode(); err != nil { - return pgsql.Query{}, err - } else if rootNodeConstraints, err := consumeConstraintsFrom(pgsql.AsIdentifierSet(rootNode.Identifier), s.treeTranslator.IdentifierConstraints, part.Constraints); err != nil { - return pgsql.Query{}, err - } else if terminalNodeConstraints, err := consumeConstraintsFrom(pgsql.AsIdentifierSet(terminalNode.Identifier), s.treeTranslator.IdentifierConstraints, part.Constraints); err != nil { + if expansionComponents, err := prepareExpansionRootComponents(part, traversalStep, s.treeTranslator); err != nil { return pgsql.Query{}, err } else { - // The exclusion below is done at this step in the process since the recursive descent portion of the query no longer has - // a reference to `n0` and any dependent interaction between `n0` and `n1` would require an additional join. By not - // consuming the remaining constraints for `n0` and `n1`, they become visible up in the outer select of the recursive CTE. - recursiveVisible := traversalStep.Expansion.Value.Frame.Visible.Copy() - recursiveVisible.Remove(rootNode.Identifier) - - if edgeConstraints, err := consumeConstraintsFrom(recursiveVisible, s.treeTranslator.IdentifierConstraints, part.Constraints); err != nil { - return pgsql.Query{}, err - } else { - // Set the edge constraints in the primer and recursive select where clauses - expansion.PrimerStatement.Where = edgeConstraints.Expression - expansion.RecursiveStatement.Where = pgsql.OptionalAnd(edgeConstraints.Expression, expansion.RecursiveStatement.Where) - } + expansion.PrimerStatement.Where = expansionComponents.PrimerWhereClause + expansion.RecursiveStatement.Where = expansionComponents.RecursiveWhereClause if leftNodeJoinConstraint, err := leftNodeTraversalStepConstraint(traversalStep); err != nil { return pgsql.Query{}, err - } else if leftNodeJoinCondition, err := ConjoinExpressions([]pgsql.Expression{rootNodeConstraints.Expression, leftNodeJoinConstraint}); err != nil { + } else if leftNodeJoinCondition, err := ConjoinExpressions([]pgsql.Expression{expansionComponents.LeftNodeConstraints.Expression, leftNodeJoinConstraint}); err != nil { return pgsql.Query{}, err } else if rightNodeJoinCondition, err := rightNodeTraversalStepConstraint(traversalStep); err != nil { return pgsql.Query{}, err @@ -615,9 +707,9 @@ func (s *Translator) buildExpansionPatternRoot(part *PatternPart, traversalStep } } - // If there are terminal constraints, project them as part of the projections - if terminalNodeConstraints.Expression != nil { - if terminalCriteriaProjection, err := pgsql.As[pgsql.SelectItem](terminalNodeConstraints.Expression); err != nil { + // If there are right node constraints, project them as part of the primer statement's projection + if expansionComponents.RightNodeConstraints.Expression != nil { + if terminalCriteriaProjection, err := pgsql.As[pgsql.SelectItem](expansionComponents.RightNodeConstraints.Expression); err != nil { return pgsql.Query{}, err } else { expansion.PrimerStatement.Projection = []pgsql.SelectItem{ @@ -761,25 +853,21 @@ func (s *Translator) buildExpansionPatternStep(part *PatternPart, traversalStep pgsql.CompoundIdentifier{traversalStep.Edge.Identifier, pgsql.ColumnID}, ), }, - - Where: expansionConstraints(traversalStep.Expansion.Value.Binding.Identifier), }, } ) expansion.ProjectionStatement.Projection = traversalStep.Expansion.Value.Projection - if terminalNode, err := traversalStep.TerminalNode(); err != nil { - return pgsql.Query{}, err - } else if terminalNodeConstraints, err := consumeConstraintsFrom(pgsql.AsIdentifierSet(terminalNode.Identifier), s.treeTranslator.IdentifierConstraints, part.Constraints); err != nil { - return pgsql.Query{}, err - } else if edgeConstraints, err := consumeConstraintsFrom(traversalStep.Expansion.Value.Frame.Visible, s.treeTranslator.IdentifierConstraints, part.Constraints); err != nil { + if expansionComponents, err := prepareExpansionStepComponents(part, traversalStep, s.treeTranslator); err != nil { return pgsql.Query{}, err } else { + expansion.RecursiveStatement.Where = expansionComponents.RecursiveWhereClause + if rightNodeJoinCondition, err := rightNodeTraversalStepConstraint(traversalStep); err != nil { return pgsql.Query{}, err } else { - if err := rewriteIdentifierReferences(traversalStep.Expansion.Value.Frame, []pgsql.Expression{edgeConstraints.Expression, rightNodeJoinCondition}); err != nil { + if err := rewriteIdentifierReferences(traversalStep.Expansion.Value.Frame, []pgsql.Expression{expansionComponents.EdgeConstraints.Expression, rightNodeJoinCondition}); err != nil { return pgsql.Query{}, err } @@ -794,7 +882,7 @@ func (s *Translator) buildExpansionPatternStep(part *PatternPart, traversalStep }, JoinOperator: pgsql.JoinOperator{ JoinType: pgsql.JoinTypeInner, - Constraint: edgeConstraints.Expression, + Constraint: expansionComponents.EdgeConstraints.Expression, }, }, { Table: pgsql.TableReference{ @@ -914,8 +1002,8 @@ func (s *Translator) buildExpansionPatternStep(part *PatternPart, traversalStep } // If there are terminal constraints, project them as part of the recursive lookup - if terminalNodeConstraints.Expression != nil { - if terminalCriteriaProjection, err := pgsql.As[pgsql.SelectItem](terminalNodeConstraints.Expression); err != nil { + if expansionComponents.RightNodeConstraints.Expression != nil { + if terminalCriteriaProjection, err := pgsql.As[pgsql.SelectItem](expansionComponents.RightNodeConstraints.Expression); err != nil { return pgsql.Query{}, err } else { expansion.RecursiveStatement.Projection = []pgsql.SelectItem{ diff --git a/packages/go/cypher/models/pgsql/translate/expression.go b/packages/go/cypher/models/pgsql/translate/expression.go index 479a4cbdf6..1da57810e6 100644 --- a/packages/go/cypher/models/pgsql/translate/expression.go +++ b/packages/go/cypher/models/pgsql/translate/expression.go @@ -596,6 +596,8 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato switch operator { case pgsql.OperatorContains: + newExpression.Operator = pgsql.OperatorLike + switch typedLOperand := newExpression.LOperand.(type) { case *pgsql.BinaryExpression: switch typedLOperand.Operator { @@ -607,7 +609,6 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato switch typedROperand := newExpression.ROperand.(type) { case *pgsql.Parameter: - newExpression.Operator = pgsql.OperatorLike newExpression.ROperand = pgsql.NewBinaryExpression( pgsql.NewLiteral("%", pgsql.Text), pgsql.OperatorConcatenate, @@ -624,7 +625,6 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato } else if stringValue, isString := typedROperand.Value.(string); !isString { return fmt.Errorf("expected string but found %T as right operand for operator %s", typedROperand.Value, operator) } else { - newExpression.Operator = pgsql.OperatorLike newExpression.ROperand = pgsql.NewLiteral("%"+stringValue+"%", rOperandDataType) } @@ -632,7 +632,6 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato if typeCastedROperand, err := TypeCastExpression(typedROperand, pgsql.Text); err != nil { return err } else { - newExpression.Operator = pgsql.OperatorLike newExpression.ROperand = pgsql.NewBinaryExpression( pgsql.NewLiteral("%", pgsql.Text), pgsql.OperatorConcatenate, @@ -652,7 +651,6 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato typedROperand.Operator = pgsql.OperatorJSONTextField } - newExpression.Operator = pgsql.OperatorLike newExpression.ROperand = pgsql.NewTypeCast(pgsql.NewBinaryExpression( stringLiteral, pgsql.OperatorConcatenate, @@ -672,7 +670,13 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato s.Push(newExpression) + case pgsql.OperatorRegexMatch: + newExpression.Operator = pgsql.OperatorSimilarTo + s.Push(newExpression) + case pgsql.OperatorStartsWith: + newExpression.Operator = pgsql.OperatorLike + switch typedLOperand := newExpression.LOperand.(type) { case *pgsql.BinaryExpression: switch typedLOperand.Operator { @@ -683,7 +687,6 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato switch typedROperand := newExpression.ROperand.(type) { case *pgsql.Parameter: - newExpression.Operator = pgsql.OperatorLike newExpression.ROperand = pgsql.NewBinaryExpression( typedROperand, pgsql.OperatorConcatenate, @@ -696,7 +699,6 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato } else if stringValue, isString := typedROperand.Value.(string); !isString { return fmt.Errorf("expected string but found %T as right operand for operator %s", typedROperand.Value, operator) } else { - newExpression.Operator = pgsql.OperatorLike newExpression.ROperand = pgsql.NewLiteral(stringValue+"%", rOperandDataType) } @@ -704,7 +706,6 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato if typeCastedROperand, err := TypeCastExpression(typedROperand, pgsql.Text); err != nil { return err } else { - newExpression.Operator = pgsql.OperatorLike newExpression.ROperand = pgsql.NewBinaryExpression( typeCastedROperand, pgsql.OperatorConcatenate, @@ -720,7 +721,6 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato typedROperand.Operator = pgsql.OperatorJSONTextField } - newExpression.Operator = pgsql.OperatorLike newExpression.ROperand = pgsql.NewTypeCast(pgsql.NewBinaryExpression( &pgsql.Parenthetical{ Expression: typedROperand, @@ -738,6 +738,8 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato s.Push(newExpression) case pgsql.OperatorEndsWith: + newExpression.Operator = pgsql.OperatorLike + switch typedLOperand := newExpression.LOperand.(type) { case *pgsql.BinaryExpression: switch typedLOperand.Operator { @@ -748,7 +750,6 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato switch typedROperand := newExpression.ROperand.(type) { case *pgsql.Parameter: - newExpression.Operator = pgsql.OperatorLike newExpression.ROperand = pgsql.NewBinaryExpression( pgsql.NewLiteral("%", pgsql.Text), pgsql.OperatorConcatenate, @@ -761,7 +762,6 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato } else if stringValue, isString := typedROperand.Value.(string); !isString { return fmt.Errorf("expected string but found %T as right operand for operator %s", typedROperand.Value, operator) } else { - newExpression.Operator = pgsql.OperatorLike newExpression.ROperand = pgsql.NewLiteral("%"+stringValue, rOperandDataType) } @@ -769,7 +769,6 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato if typeCastedROperand, err := TypeCastExpression(typedROperand, pgsql.Text); err != nil { return err } else { - newExpression.Operator = pgsql.OperatorLike newExpression.ROperand = pgsql.NewBinaryExpression( pgsql.NewLiteral("%", pgsql.Text), pgsql.OperatorConcatenate, @@ -782,7 +781,6 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato typedROperand.Operator = pgsql.OperatorJSONTextField } - newExpression.Operator = pgsql.OperatorLike newExpression.ROperand = pgsql.NewTypeCast(pgsql.NewBinaryExpression( pgsql.NewLiteral("%", pgsql.Text), pgsql.OperatorConcatenate, diff --git a/packages/go/graphschema/ad/ad.go b/packages/go/graphschema/ad/ad.go index ae56f9a542..78f29c8952 100644 --- a/packages/go/graphschema/ad/ad.go +++ b/packages/go/graphschema/ad/ad.go @@ -21,6 +21,7 @@ package ad import ( "errors" + graph "github.com/specterops/bloodhound/dawgs/graph" ) diff --git a/packages/go/graphschema/azure/azure.go b/packages/go/graphschema/azure/azure.go index 00b20f190f..787ee392e6 100644 --- a/packages/go/graphschema/azure/azure.go +++ b/packages/go/graphschema/azure/azure.go @@ -21,6 +21,7 @@ package azure import ( "errors" + graph "github.com/specterops/bloodhound/dawgs/graph" ) diff --git a/packages/go/graphschema/common/common.go b/packages/go/graphschema/common/common.go index 631871c6bf..73edf123fa 100644 --- a/packages/go/graphschema/common/common.go +++ b/packages/go/graphschema/common/common.go @@ -21,6 +21,7 @@ package common import ( "errors" + graph "github.com/specterops/bloodhound/dawgs/graph" )