clair/database/pgsql/namespace.go
Sida Chen 0c1b80b2ed pgsql: Implement database queries for detector relationship
* Refactor layer and ancestry
* Add tests
* Fix bugs introduced when the queries were moved
2018-10-08 11:27:15 -04:00

96 lines
2.5 KiB
Go

// Copyright 2017 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"database/sql"
"sort"
"github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/commonerr"
)
// searchNamespaceID selects the id of the Namespace row whose name equals the
// first query parameter and whose version_format equals the second.
const searchNamespaceID = `SELECT id FROM Namespace WHERE name = $1 AND version_format = $2`
// PersistNamespaces writes the given namespaces to the database by executing
// the statement built by queryPersistNamespace, passing one
// (Name, VersionFormat) parameter pair per namespace.
//
// An empty input is a no-op. If any namespace has an empty Name or
// VersionFormat, a bad-request error is returned and nothing is executed.
// Note that the input slice is sorted in place.
func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error {
	if len(namespaces) == 0 {
		return nil
	}

	// Sorting is needed before inserting into database to prevent deadlock.
	// The comparison must be a consistent total order (Name first, then
	// VersionFormat as the tie-breaker); otherwise different callers could end
	// up with different orderings of the same set of namespaces.
	sort.Slice(namespaces, func(i, j int) bool {
		if namespaces[i].Name != namespaces[j].Name {
			return namespaces[i].Name < namespaces[j].Name
		}
		return namespaces[i].VersionFormat < namespaces[j].VersionFormat
	})

	// Flatten the namespaces into the positional parameter list:
	// [name0, versionFormat0, name1, versionFormat1, ...].
	keys := make([]interface{}, len(namespaces)*2)
	for i, ns := range namespaces {
		if ns.Name == "" || ns.VersionFormat == "" {
			return commonerr.NewBadRequestError("Empty namespace name or version format is not allowed")
		}
		keys[i*2] = ns.Name
		keys[i*2+1] = ns.VersionFormat
	}

	if _, err := tx.Exec(queryPersistNamespace(len(namespaces)), keys...); err != nil {
		return handleError("queryPersistNamespace", err)
	}
	return nil
}
// findNamespaceIDs looks up the database ids of the given namespaces using the
// query built by querySearchNamespace.
//
// The returned slice has the same length and order as namespaces: ids[i] is
// the id scanned for namespaces[i]. A namespace for which the query returned
// no row keeps the zero sql.NullInt64 (Valid == false). An empty input yields
// (nil, nil).
func (tx *pgSession) findNamespaceIDs(namespaces []database.Namespace) ([]sql.NullInt64, error) {
	if len(namespaces) == 0 {
		return nil, nil
	}

	// Flatten the namespaces into the positional parameter list and seed the
	// lookup table with a "not found" entry for every requested namespace.
	keys := make([]interface{}, len(namespaces)*2)
	nsMap := make(map[database.Namespace]sql.NullInt64, len(namespaces))
	for i, n := range namespaces {
		keys[i*2] = n.Name
		keys[i*2+1] = n.VersionFormat
		nsMap[n] = sql.NullInt64{}
	}

	rows, err := tx.Query(querySearchNamespace(len(namespaces)), keys...)
	if err != nil {
		return nil, handleError("searchNamespace", err)
	}
	defer rows.Close()

	var (
		id sql.NullInt64
		ns database.Namespace
	)
	for rows.Next() {
		if err := rows.Scan(&id, &ns.Name, &ns.VersionFormat); err != nil {
			return nil, handleError("searchNamespace", err)
		}
		nsMap[ns] = id
	}
	// rows.Next returns false both at end of data and on failure; without this
	// check an iteration error would be misreported as "namespace not found".
	if err := rows.Err(); err != nil {
		return nil, handleError("searchNamespace", err)
	}

	// Re-project the results onto the caller's ordering.
	ids := make([]sql.NullInt64, len(namespaces))
	for i, n := range namespaces {
		ids[i] = nsMap[n]
	}
	return ids, nil
}