From 9c3fdd74cf98545bec421a95a1230f2a6fc6948f Mon Sep 17 00:00:00 2001 From: jhoward-lm <140011346+jhoward-lm@users.noreply.github.com> Date: Tue, 10 Sep 2024 11:04:34 -0500 Subject: [PATCH] refactor: add reusable function that runs callbacks in a transaction (#31) * refactor: add reusable function that runs callbacks in a transaction refactor: - optionally pass annotations through protobom StoreOptions Signed-off-by: Jonathan Howard * fix: return no results if no document IDs are passed in Signed-off-by: Jonathan Howard * fix: return no results if no document has specified annotation values Signed-off-by: Jonathan Howard * refactor: clear annotations on backend options Signed-off-by: Jonathan Howard --------- Signed-off-by: Jonathan Howard --- backends/ent/annotations.go | 116 +++---- backends/ent/backend.go | 38 ++- backends/ent/options.go | 11 + backends/ent/store.go | 613 +++++++++++++++++------------------- 4 files changed, 374 insertions(+), 404 deletions(-) diff --git a/backends/ent/annotations.go b/backends/ent/annotations.go index 4c91244..4ae8fe2 100644 --- a/backends/ent/annotations.go +++ b/backends/ent/annotations.go @@ -16,39 +16,6 @@ import ( "github.com/protobom/storage/internal/backends/ent/predicate" ) -func (backend *Backend) createAnnotations(data ...*ent.Annotation) error { - tx, err := backend.txClient() - if err != nil { - return err - } - - builders := []*ent.AnnotationCreate{} - - for idx := range data { - builder := tx.Annotation.Create(). - SetDocumentID(data[idx].DocumentID). - SetName(data[idx].Name). - SetValue(data[idx].Value). - SetIsUnique(data[idx].IsUnique) - - builders = append(builders, builder) - } - - err = tx.Annotation.CreateBulk(builders...). - OnConflict(). - UpdateNewValues(). - Exec(backend.ctx) - if err != nil && !ent.IsConstraintError(err) { - return rollback(tx, fmt.Errorf("creating annotations: %w", err)) - } - - if err := tx.Commit(); err != nil { - return rollback(tx, err) - } - - return nil -} - // AddAnnotations applies multiple named annotation values to a single document. func (backend *Backend) AddAnnotations(documentID, name string, values ...string) error { data := ent.Annotations{} @@ -60,7 +27,7 @@ func (backend *Backend) AddAnnotations(documentID, name string, values ...string }) } - return backend.createAnnotations(data...) + return backend.withTx(backend.saveAnnotations(data...)) } // AddAnnotationToDocuments applies a single named annotation value to multiple documents. @@ -74,7 +41,7 @@ func (backend *Backend) AddAnnotationToDocuments(name, value string, documentIDs }) } - return backend.createAnnotations(data...) + return backend.withTx(backend.saveAnnotations(data...)) } // ClearAnnotations removes all annotations from the specified documents. @@ -83,21 +50,14 @@ func (backend *Backend) ClearAnnotations(documentIDs ...string) error { return nil } - tx, err := backend.txClient() - if err != nil { - return err - } - - _, err = tx.Annotation.Delete().Where(annotation.DocumentIDIn(documentIDs...)).Exec(backend.ctx) - if err != nil { - return rollback(tx, fmt.Errorf("clearing annotations: %w", err)) - } - - if err := tx.Commit(); err != nil { - return rollback(tx, err) - } + return backend.withTx(func(tx *ent.Tx) error { + _, err := tx.Annotation.Delete().Where(annotation.DocumentIDIn(documentIDs...)).Exec(backend.ctx) + if err != nil { + return fmt.Errorf("clearing annotations: %w", err) + } - return nil + return nil + }) } // GetDocumentAnnotations gets all annotations for the specified @@ -143,6 +103,10 @@ func (backend *Backend) GetDocumentsByAnnotation(name string, values ...string) return nil, fmt.Errorf("querying documents table: %w", err) } + if len(ids) == 0 { + return []*sbom.Document{}, nil + } + return backend.GetDocumentsByID(ids...) } @@ -173,29 +137,23 @@ func (backend *Backend) GetDocumentUniqueAnnotation(documentID, name string) (st // RemoveAnnotations removes all annotations with the specified name from // the document, limited to a set of annotation values if specified. func (backend *Backend) RemoveAnnotations(documentID, name string, values ...string) error { - tx, err := backend.txClient() - if err != nil { - return err - } - - predicates := []predicate.Annotation{ - annotation.DocumentIDEQ(documentID), - annotation.NameEQ(name), - } - - if len(values) > 0 { - predicates = append(predicates, annotation.ValueIn(values...)) - } - - if _, err := tx.Annotation.Delete().Where(predicates...).Exec(backend.ctx); err != nil { - return rollback(tx, fmt.Errorf("removing annotations: %w", err)) - } - - if err := tx.Commit(); err != nil { - return rollback(tx, err) - } - - return nil + return backend.withTx( + func(tx *ent.Tx) error { + predicates := []predicate.Annotation{ + annotation.DocumentIDEQ(documentID), + annotation.NameEQ(name), + } + + if len(values) > 0 { + predicates = append(predicates, annotation.ValueIn(values...)) + } + + if _, err := tx.Annotation.Delete().Where(predicates...).Exec(backend.ctx); err != nil { + return fmt.Errorf("removing annotations: %w", err) + } + + return nil + }) } // SetAnnotations explicitly sets the named annotations for the specified document. @@ -209,10 +167,12 @@ func (backend *Backend) SetAnnotations(documentID, name string, values ...string // SetUniqueAnnotation sets a named annotation value that is unique to the specified document. func (backend *Backend) SetUniqueAnnotation(documentID, name, value string) error { - return backend.createAnnotations(&ent.Annotation{ - DocumentID: documentID, - Name: name, - Value: value, - IsUnique: true, - }) + return backend.withTx( + backend.saveAnnotations(&ent.Annotation{ + DocumentID: documentID, + Name: name, + Value: value, + IsUnique: true, + }), + ) } diff --git a/backends/ent/backend.go b/backends/ent/backend.go index dade3b2..db06c95 100644 --- a/backends/ent/backend.go +++ b/backends/ent/backend.go @@ -81,6 +81,22 @@ func (backend *Backend) Debug() *Backend { return backend } +func (backend *Backend) WithAnnotation(name, value string, unique bool) *Backend { + backend.Options.Annotations = append(backend.Options.Annotations, &Annotation{ + Name: name, + Value: value, + IsUnique: unique, + }) + + return backend +} + +func (backend *Backend) WithAnnotations(annotations Annotations) *Backend { + backend.Options.Annotations = append(backend.Options.Annotations, annotations...) + + return backend +} + func (backend *Backend) WithBackendOptions(opts *BackendOptions) *Backend { backend.Options = opts @@ -93,17 +109,31 @@ func (backend *Backend) WithDatabaseFile(file string) *Backend { return backend } -func (backend *Backend) txClient() (*ent.Tx, error) { +func (backend *Backend) withTx(fns ...txFunc) error { if backend.client == nil { - return nil, fmt.Errorf("%w", errUninitializedClient) + return fmt.Errorf("%w", errUninitializedClient) } tx, err := backend.client.Tx(backend.ctx) if err != nil { - return nil, fmt.Errorf("creating transactional client: %w", err) + return fmt.Errorf("creating transactional client: %w", err) } backend.ctx = ent.NewTxContext(backend.ctx, tx) - return tx, nil + for _, fn := range fns { + if err := fn(tx); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + err = fmt.Errorf("%w: rolling back transaction: %w", err, rollbackErr) + } + + return err + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("committing transaction: %w", err) + } + + return nil } diff --git a/backends/ent/options.go b/backends/ent/options.go index f4744f4..a6d3944 100644 --- a/backends/ent/options.go +++ b/backends/ent/options.go @@ -7,6 +7,8 @@ package ent import ( "errors" + + "github.com/protobom/storage/internal/backends/ent" ) // Enable SQLite foreign key support. @@ -18,11 +20,20 @@ var ( ) type ( + // Annotation is the model entity for the Annotation schema. + Annotation = ent.Annotation + + // Annotations is a parsable slice of Annotation. + Annotations = ent.Annotations + // BackendOptions contains options specific to the protobom ent backend. BackendOptions struct { // DatabaseFile is the file path of the SQLite database to be created. DatabaseFile string + // Annotations is a slice of annotations to apply to stored document. + Annotations + // Debug configures the ent client to output all SQL statements during execution. Debug bool } diff --git a/backends/ent/store.go b/backends/ent/store.go index 7c61954..b3f8c00 100644 --- a/backends/ent/store.go +++ b/backends/ent/store.go @@ -8,6 +8,7 @@ package ent import ( "context" "fmt" + "slices" "github.com/protobom/protobom/pkg/sbom" "github.com/protobom/protobom/pkg/storage" @@ -28,6 +29,8 @@ type ( metadataIDKey struct{} nodeIDKey struct{} nodeListIDKey struct{} + + txFunc func(*ent.Tx) error ) // Store implements the storage.Storer interface. @@ -42,436 +45,402 @@ func (backend *Backend) Store(doc *sbom.Document, opts *storage.StoreOptions) er } } - if _, ok := opts.BackendOptions.(*BackendOptions); !ok { + backendOpts, ok := opts.BackendOptions.(*BackendOptions) + if !ok { return fmt.Errorf("%w", errInvalidEntOptions) } - tx, err := backend.client.Tx(backend.ctx) - if err != nil { - return fmt.Errorf("creating transactional client: %w", err) - } - - backend.ctx = ent.NewTxContext(backend.ctx, tx) - - if err := tx.Document.Create(). - SetID(doc.Metadata.Id). - OnConflict(). - Ignore(). - Exec(backend.ctx); err != nil && !ent.IsConstraintError(err) { - return rollback(tx, fmt.Errorf("ent.Document: %w", err)) - } - - if err := backend.saveMetadata(doc.Metadata); err != nil { - return rollback(tx, err) - } + // Append annotations from opts parameter with any previously set on the backend. + annotations := slices.Concat(backend.Options.Annotations, backendOpts.Annotations) + clear(backend.Options.Annotations) - if err := backend.saveNodeList(doc.NodeList); err != nil { - return rollback(tx, err) - } - - if err := tx.Commit(); err != nil { - return rollback(tx, err) + // Set each annotation's document ID if not specified. + for _, a := range annotations { + if a.DocumentID == "" { + a.DocumentID = doc.Metadata.Id + } } - return nil + return backend.withTx( + func(tx *ent.Tx) error { + return tx.Document.Create(). + SetID(doc.Metadata.Id). + OnConflict(). + Ignore(). + Exec(backend.ctx) + }, + backend.saveAnnotations(annotations...), + backend.saveMetadata(doc.Metadata), + backend.saveNodeList(doc.NodeList), + ) } -func (backend *Backend) saveDocumentTypes(docTypes []*sbom.DocumentType) error { - if backend.client == nil { - return fmt.Errorf("%w", errUninitializedClient) - } - - tx := ent.TxFromContext(backend.ctx) - - for _, dt := range docTypes { - typeName := documenttype.Type(dt.Type.String()) +func (backend *Backend) saveAnnotations(annotations ...*ent.Annotation) txFunc { + return func(tx *ent.Tx) error { + builders := []*ent.AnnotationCreate{} - newDocType := tx.DocumentType.Create(). - SetNillableType(&typeName). - SetNillableName(dt.Name). - SetNillableDescription(dt.Description) + for idx := range annotations { + builder := tx.Annotation.Create(). + SetDocumentID(annotations[idx].DocumentID). + SetName(annotations[idx].Name). + SetValue(annotations[idx].Value). + SetIsUnique(annotations[idx].IsUnique) - if metadataID, ok := backend.ctx.Value(metadataIDKey{}).(string); ok { - newDocType.SetMetadataID(metadataID) + builders = append(builders, builder) } - err := newDocType.OnConflict().Ignore().Exec(backend.ctx) + err := tx.Annotation.CreateBulk(builders...). + OnConflict(). + UpdateNewValues(). + Exec(backend.ctx) if err != nil && !ent.IsConstraintError(err) { - return fmt.Errorf("ent.DocumentType: %w", err) + return fmt.Errorf("creating annotations: %w", err) } - } - return nil + return nil + } } -func (backend *Backend) saveEdges(edges []*sbom.Edge) error { - if backend.client == nil { - return fmt.Errorf("%w", errUninitializedClient) - } +func (backend *Backend) saveDocumentTypes(docTypes []*sbom.DocumentType) txFunc { + return func(tx *ent.Tx) error { + for _, dt := range docTypes { + typeName := documenttype.Type(dt.Type.String()) - tx := ent.TxFromContext(backend.ctx) + newDocType := tx.DocumentType.Create(). + SetNillableType(&typeName). + SetNillableName(dt.Name). + SetNillableDescription(dt.Description) - for _, edge := range edges { - for _, toID := range edge.To { - newEdgeType := tx.EdgeType.Create(). - SetType(edgetype.Type(edge.Type.String())). - SetFromID(edge.From). - SetToID(toID) + if metadataID, ok := backend.ctx.Value(metadataIDKey{}).(string); ok { + newDocType.SetMetadataID(metadataID) + } - err := newEdgeType.OnConflict().Ignore().Exec(backend.ctx) + err := newDocType.OnConflict().Ignore().Exec(backend.ctx) if err != nil && !ent.IsConstraintError(err) { - return fmt.Errorf("ent.Node: %w", err) + return fmt.Errorf("ent.DocumentType: %w", err) } } - } - return nil -} - -func (backend *Backend) saveExternalReferences(refs []*sbom.ExternalReference) error { - if backend.client == nil { - return fmt.Errorf("%w", errUninitializedClient) + return nil } +} - tx := ent.TxFromContext(backend.ctx) - - for _, ref := range refs { - newRef := tx.ExternalReference.Create(). - SetURL(ref.Url). - SetComment(ref.Comment). - SetAuthority(ref.Authority). - SetType(externalreference.Type(ref.Type.String())) - - if nodeID, ok := backend.ctx.Value(nodeIDKey{}).(string); ok { - newRef.SetNodeID(nodeID) - } - - id, err := newRef.OnConflict().Ignore().ID(backend.ctx) - if err != nil && !ent.IsConstraintError(err) { - return fmt.Errorf("ent.ExternalReference: %w", err) +func (backend *Backend) saveEdges(edges []*sbom.Edge) txFunc { + return func(tx *ent.Tx) error { + for _, edge := range edges { + for _, toID := range edge.To { + newEdgeType := tx.EdgeType.Create(). + SetType(edgetype.Type(edge.Type.String())). + SetFromID(edge.From). + SetToID(toID) + + err := newEdgeType.OnConflict().Ignore().Exec(backend.ctx) + if err != nil && !ent.IsConstraintError(err) { + return fmt.Errorf("ent.Node: %w", err) + } + } } - backend.ctx = context.WithValue(backend.ctx, externalReferenceIDKey{}, id) - - if err := backend.saveHashesEntries(ref.Hashes); err != nil { - return err - } + return nil } - - return nil } -func (backend *Backend) saveHashesEntries(hashes map[int32]string) error { - if backend.client == nil { - return fmt.Errorf("%w", errUninitializedClient) - } - - tx := ent.TxFromContext(backend.ctx) - entries := []*ent.HashesEntryCreate{} - - for alg, content := range hashes { - algName := sbom.HashAlgorithm(alg).String() +func (backend *Backend) saveExternalReferences(refs []*sbom.ExternalReference) txFunc { + return func(tx *ent.Tx) error { + for _, ref := range refs { + newRef := tx.ExternalReference.Create(). + SetURL(ref.Url). + SetComment(ref.Comment). + SetAuthority(ref.Authority). + SetType(externalreference.Type(ref.Type.String())) + + if nodeID, ok := backend.ctx.Value(nodeIDKey{}).(string); ok { + newRef.SetNodeID(nodeID) + } - entry := tx.HashesEntry.Create(). - SetHashAlgorithmType(hashesentry.HashAlgorithmType(algName)). - SetHashData(content) + id, err := newRef.OnConflict().Ignore().ID(backend.ctx) + if err != nil && !ent.IsConstraintError(err) { + return fmt.Errorf("ent.ExternalReference: %w", err) + } - if externalReferenceID, ok := backend.ctx.Value(externalReferenceIDKey{}).(int); ok { - entry.SetExternalReferenceID(externalReferenceID) - } + backend.ctx = context.WithValue(backend.ctx, externalReferenceIDKey{}, id) - if nodeID, ok := backend.ctx.Value(nodeIDKey{}).(string); ok { - entry.SetNodeID(nodeID) + if err := backend.saveHashesEntries(ref.Hashes)(tx); err != nil { + return err + } } - entries = append(entries, entry) - } - - if err := tx.HashesEntry.CreateBulk(entries...). - Exec(backend.ctx); err != nil && !ent.IsConstraintError(err) { - return fmt.Errorf("ent.HashesEntry: %w", err) + return nil } - - return nil } -func (backend *Backend) saveIdentifiersEntries(idents map[int32]string) error { - if backend.client == nil { - return fmt.Errorf("%w", errUninitializedClient) - } +func (backend *Backend) saveHashesEntries(hashes map[int32]string) txFunc { + return func(tx *ent.Tx) error { + entries := []*ent.HashesEntryCreate{} - tx := ent.TxFromContext(backend.ctx) - entries := []*ent.IdentifiersEntryCreate{} + for alg, content := range hashes { + algName := sbom.HashAlgorithm(alg).String() - for typ, value := range idents { - typeName := sbom.SoftwareIdentifierType(typ).String() + entry := tx.HashesEntry.Create(). + SetHashAlgorithmType(hashesentry.HashAlgorithmType(algName)). + SetHashData(content) - entry := tx.IdentifiersEntry.Create(). - SetSoftwareIdentifierType(identifiersentry.SoftwareIdentifierType(typeName)). - SetSoftwareIdentifierValue(value) + if externalReferenceID, ok := backend.ctx.Value(externalReferenceIDKey{}).(int); ok { + entry.SetExternalReferenceID(externalReferenceID) + } - if nodeID, ok := backend.ctx.Value(nodeIDKey{}).(string); ok { - entry.SetNodeID(nodeID) + if nodeID, ok := backend.ctx.Value(nodeIDKey{}).(string); ok { + entry.SetNodeID(nodeID) + } + + entries = append(entries, entry) } - entries = append(entries, entry) - } + if err := tx.HashesEntry.CreateBulk(entries...). + Exec(backend.ctx); err != nil && !ent.IsConstraintError(err) { + return fmt.Errorf("ent.HashesEntry: %w", err) + } - if err := tx.IdentifiersEntry.CreateBulk(entries...). - Exec(backend.ctx); err != nil && !ent.IsConstraintError(err) { - return fmt.Errorf("ent.IdentifiersEntry: %w", err) + return nil } - - return nil } -func (backend *Backend) saveMetadata(md *sbom.Metadata) error { - if backend.client == nil { - return fmt.Errorf("%w", errUninitializedClient) - } +func (backend *Backend) saveIdentifiersEntries(idents map[int32]string) txFunc { + return func(tx *ent.Tx) error { + entries := []*ent.IdentifiersEntryCreate{} - tx := ent.TxFromContext(backend.ctx) + for typ, value := range idents { + typeName := sbom.SoftwareIdentifierType(typ).String() - newMetadata := tx.Metadata.Create(). - SetID(md.Id). - SetDocumentID(md.Id). - SetVersion(md.Version). - SetName(md.Name). - SetComment(md.Comment). - SetDate(md.Date.AsTime()) - - err := newMetadata.OnConflict().Ignore().Exec(backend.ctx) - if err != nil && !ent.IsConstraintError(err) { - return fmt.Errorf("ent.Metadata: %w", err) - } + entry := tx.IdentifiersEntry.Create(). + SetSoftwareIdentifierType(identifiersentry.SoftwareIdentifierType(typeName)). + SetSoftwareIdentifierValue(value) - backend.ctx = context.WithValue(backend.ctx, metadataIDKey{}, md.Id) + if nodeID, ok := backend.ctx.Value(nodeIDKey{}).(string); ok { + entry.SetNodeID(nodeID) + } - if err := backend.savePersons(md.Authors); err != nil { - return err - } + entries = append(entries, entry) + } - if err := backend.saveDocumentTypes(md.DocumentTypes); err != nil { - return err - } + if err := tx.IdentifiersEntry.CreateBulk(entries...). + Exec(backend.ctx); err != nil && !ent.IsConstraintError(err) { + return fmt.Errorf("ent.IdentifiersEntry: %w", err) + } - if err := backend.saveTools(md.Tools); err != nil { - return err + return nil } - - return nil } -func (backend *Backend) saveNodeList(nodeList *sbom.NodeList) error { - if backend.client == nil { - return fmt.Errorf("%w", errUninitializedClient) - } - - tx := ent.TxFromContext(backend.ctx) - newNodeList := tx.NodeList.Create(). - SetRootElements(nodeList.RootElements) - - if documentID, ok := backend.ctx.Value(metadataIDKey{}).(string); ok { - newNodeList.SetDocumentID(documentID) - } - - id, err := newNodeList.OnConflict().Ignore().ID(backend.ctx) - if err != nil && !ent.IsConstraintError(err) { - return fmt.Errorf("ent.NodeList: %w", err) - } +func (backend *Backend) saveMetadata(md *sbom.Metadata) txFunc { + return func(tx *ent.Tx) error { + newMetadata := tx.Metadata.Create(). + SetID(md.Id). + SetDocumentID(md.Id). + SetVersion(md.Version). + SetName(md.Name). + SetComment(md.Comment). + SetDate(md.Date.AsTime()) + + err := newMetadata.OnConflict().Ignore().Exec(backend.ctx) + if err != nil && !ent.IsConstraintError(err) { + return fmt.Errorf("ent.Metadata: %w", err) + } - backend.ctx = context.WithValue(backend.ctx, nodeListIDKey{}, id) + backend.ctx = context.WithValue(backend.ctx, metadataIDKey{}, md.Id) - if err := backend.saveNodes(nodeList.Nodes); err != nil { - return err - } + for _, fn := range []txFunc{ + backend.savePersons(md.Authors), + backend.saveDocumentTypes(md.DocumentTypes), + backend.saveTools(md.Tools), + } { + if err := fn(tx); err != nil { + return err + } + } - // Update nodes of this node list with their typed edges. - if err := backend.saveEdges(nodeList.Edges); err != nil { - return err + return nil } - - return nil } -func (backend *Backend) saveNodes(nodes []*sbom.Node) error { //nolint:cyclop - if backend.client == nil { - return fmt.Errorf("%w", errUninitializedClient) - } +func (backend *Backend) saveNodeList(nodeList *sbom.NodeList) txFunc { + return func(tx *ent.Tx) error { + newNodeList := tx.NodeList.Create(). + SetRootElements(nodeList.RootElements) - for _, n := range nodes { - newNode := backend.newNodeCreate(n) + if documentID, ok := backend.ctx.Value(metadataIDKey{}).(string); ok { + newNodeList.SetDocumentID(documentID) + } - if err := newNode.OnConflict().Ignore().Exec(backend.ctx); err != nil && !ent.IsConstraintError(err) { - return fmt.Errorf("ent.Node: %w", err) + id, err := newNodeList.OnConflict().Ignore().ID(backend.ctx) + if err != nil && !ent.IsConstraintError(err) { + return fmt.Errorf("ent.NodeList: %w", err) } - backend.ctx = context.WithValue(backend.ctx, nodeIDKey{}, n.Id) + backend.ctx = context.WithValue(backend.ctx, nodeListIDKey{}, id) - if err := backend.saveExternalReferences(n.ExternalReferences); err != nil { + if err := backend.saveNodes(nodeList.Nodes)(tx); err != nil { return err } - if err := backend.savePersons(n.Originators); err != nil { + // Update nodes of this node list with their typed edges. + if err := backend.saveEdges(nodeList.Edges)(tx); err != nil { return err } - if err := backend.savePersons(n.Suppliers); err != nil { - return err - } + return nil + } +} - if err := backend.savePurposes(n.PrimaryPurpose); err != nil { - return err +func (backend *Backend) saveNode(n *sbom.Node) txFunc { + return func(tx *ent.Tx) error { + newNode := tx.Node.Create(). + SetID(n.Id). + SetAttribution(n.Attribution). + SetBuildDate(n.BuildDate.AsTime()). + SetComment(n.Comment). + SetCopyright(n.Copyright). + SetDescription(n.Description). + SetFileName(n.FileName). + SetFileTypes(n.FileTypes). + SetLicenseComments(n.LicenseComments). + SetLicenseConcluded(n.LicenseConcluded). + SetLicenses(n.Licenses). + SetName(n.Name). + SetReleaseDate(n.ReleaseDate.AsTime()). + SetSourceInfo(n.SourceInfo). + SetSummary(n.Summary). + SetType(node.Type(n.Type.String())). + SetURLDownload(n.UrlDownload). + SetURLHome(n.UrlHome). + SetValidUntilDate(n.ValidUntilDate.AsTime()). + SetVersion(n.Version) + + if nodeListID, ok := backend.ctx.Value(nodeListIDKey{}).(int); ok { + newNode.AddNodeListIDs(nodeListID) } - if err := backend.saveHashesEntries(n.Hashes); err != nil { - return err + if err := newNode.OnConflict().Ignore().Exec(backend.ctx); err != nil && !ent.IsConstraintError(err) { + return fmt.Errorf("ent.Node: %w", err) } - if err := backend.saveIdentifiersEntries(n.Identifiers); err != nil { - return err - } + return nil } - - return nil } -func (backend *Backend) savePersons(persons []*sbom.Person) error { - if backend.client == nil { - return fmt.Errorf("%w", errUninitializedClient) - } +func (backend *Backend) saveNodes(nodes []*sbom.Node) txFunc { + return func(tx *ent.Tx) error { + for _, n := range nodes { + if err := backend.saveNode(n)(tx); err != nil { + return fmt.Errorf("ent.Node: %w", err) + } - tx := ent.TxFromContext(backend.ctx) + backend.ctx = context.WithValue(backend.ctx, nodeIDKey{}, n.Id) + + for _, fn := range []txFunc{ + backend.saveExternalReferences(n.ExternalReferences), + backend.savePersons(n.Originators), + backend.savePersons(n.Suppliers), + backend.savePurposes(n.PrimaryPurpose), + backend.saveHashesEntries(n.Hashes), + backend.saveIdentifiersEntries(n.Identifiers), + } { + if err := fn(tx); err != nil { + return err + } + } + } - for _, p := range persons { - newPerson := tx.Person.Create(). - SetName(p.Name). - SetEmail(p.Email). - SetIsOrg(p.IsOrg). - SetPhone(p.Phone). - SetURL(p.Url) + return nil + } +} - if contactOwnerID, ok := backend.ctx.Value(contactOwnerIDKey{}).(int); ok { - newPerson.SetContactOwnerID(contactOwnerID) - } +func (backend *Backend) savePersons(persons []*sbom.Person) txFunc { + return func(tx *ent.Tx) error { + for _, p := range persons { + newPerson := tx.Person.Create(). + SetName(p.Name). + SetEmail(p.Email). + SetIsOrg(p.IsOrg). + SetPhone(p.Phone). + SetURL(p.Url) + + if contactOwnerID, ok := backend.ctx.Value(contactOwnerIDKey{}).(int); ok { + newPerson.SetContactOwnerID(contactOwnerID) + } - if metadataID, ok := backend.ctx.Value(metadataIDKey{}).(string); ok { - newPerson.SetMetadataID(metadataID) - } + if metadataID, ok := backend.ctx.Value(metadataIDKey{}).(string); ok { + newPerson.SetMetadataID(metadataID) + } - id, err := newPerson.OnConflict().Ignore().ID(backend.ctx) - if err != nil && !ent.IsConstraintError(err) { - return fmt.Errorf("ent.ExternalReference: %w", err) - } + id, err := newPerson.OnConflict().Ignore().ID(backend.ctx) + if err != nil && !ent.IsConstraintError(err) { + return fmt.Errorf("ent.ExternalReference: %w", err) + } - backend.ctx = context.WithValue(backend.ctx, contactOwnerIDKey{}, id) + backend.ctx = context.WithValue(backend.ctx, contactOwnerIDKey{}, id) - if err := backend.savePersons(p.Contacts); err != nil { - return err + if err := backend.savePersons(p.Contacts)(tx); err != nil { + return err + } } - } - return nil + return nil + } } -func (backend *Backend) savePurposes(purposes []sbom.Purpose) error { - if backend.client == nil { - return fmt.Errorf("%w", errUninitializedClient) - } +func (backend *Backend) savePurposes(purposes []sbom.Purpose) txFunc { + return func(tx *ent.Tx) error { + builders := []*ent.PurposeCreate{} - tx := ent.TxFromContext(backend.ctx) - builders := []*ent.PurposeCreate{} + for idx := range purposes { + newPurpose := tx.Purpose.Create(). + SetPrimaryPurpose(purpose.PrimaryPurpose(purposes[idx].String())) - for idx := range purposes { - newPurpose := tx.Purpose.Create(). - SetPrimaryPurpose(purpose.PrimaryPurpose(purposes[idx].String())) + if nodeID, ok := backend.ctx.Value(nodeIDKey{}).(string); ok { + newPurpose.SetNodeID(nodeID) + } - if nodeID, ok := backend.ctx.Value(nodeIDKey{}).(string); ok { - newPurpose.SetNodeID(nodeID) + builders = append(builders, newPurpose) } - builders = append(builders, newPurpose) - } + err := tx.Purpose.CreateBulk(builders...). + OnConflict(). + Ignore(). + Exec(backend.ctx) + if err != nil && !ent.IsConstraintError(err) { + return fmt.Errorf("ent.Tool: %w", err) + } - err := tx.Purpose.CreateBulk(builders...). - OnConflict(). - Ignore(). - Exec(backend.ctx) - if err != nil && !ent.IsConstraintError(err) { - return fmt.Errorf("ent.Tool: %w", err) + return nil } - - return nil } -func (backend *Backend) saveTools(tools []*sbom.Tool) error { - if backend.client == nil { - return fmt.Errorf("%w", errUninitializedClient) - } +func (backend *Backend) saveTools(tools []*sbom.Tool) txFunc { + return func(tx *ent.Tx) error { + builders := []*ent.ToolCreate{} - tx := ent.TxFromContext(backend.ctx) - builders := []*ent.ToolCreate{} + for _, t := range tools { + newTool := tx.Tool.Create(). + SetName(t.Name). + SetVersion(t.Version). + SetVendor(t.Vendor) - for _, t := range tools { - newTool := tx.Tool.Create(). - SetName(t.Name). - SetVersion(t.Version). - SetVendor(t.Vendor) + if metadataID, ok := backend.ctx.Value(metadataIDKey{}).(string); ok { + newTool.SetMetadataID(metadataID) + } - if metadataID, ok := backend.ctx.Value(metadataIDKey{}).(string); ok { - newTool.SetMetadataID(metadataID) + builders = append(builders, newTool) } - builders = append(builders, newTool) - } - - err := tx.Tool.CreateBulk(builders...). - OnConflict(). - Ignore(). - Exec(backend.ctx) - if err != nil && !ent.IsConstraintError(err) { - return fmt.Errorf("ent.Tool: %w", err) - } - - return nil -} - -func (backend *Backend) newNodeCreate(n *sbom.Node) *ent.NodeCreate { - tx := ent.TxFromContext(backend.ctx) - - newNode := tx.Node.Create(). - SetID(n.Id). - SetAttribution(n.Attribution). - SetBuildDate(n.BuildDate.AsTime()). - SetComment(n.Comment). - SetCopyright(n.Copyright). - SetDescription(n.Description). - SetFileName(n.FileName). - SetFileTypes(n.FileTypes). - SetLicenseComments(n.LicenseComments). - SetLicenseConcluded(n.LicenseConcluded). - SetLicenses(n.Licenses). - SetName(n.Name). - SetReleaseDate(n.ReleaseDate.AsTime()). - SetSourceInfo(n.SourceInfo). - SetSummary(n.Summary). - SetType(node.Type(n.Type.String())). - SetURLDownload(n.UrlDownload). - SetURLHome(n.UrlHome). - SetValidUntilDate(n.ValidUntilDate.AsTime()). - SetVersion(n.Version) - - if nodeListID, ok := backend.ctx.Value(nodeListIDKey{}).(int); ok { - newNode.AddNodeListIDs(nodeListID) - } - - return newNode -} + err := tx.Tool.CreateBulk(builders...). + OnConflict(). + Ignore(). + Exec(backend.ctx) + if err != nil && !ent.IsConstraintError(err) { + return fmt.Errorf("ent.Tool: %w", err) + } -func rollback(tx *ent.Tx, err error) error { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return fmt.Errorf("rolling back transaction: %w", rollbackErr) + return nil } - - return err }