basic git remote functionality
diff --git a/reclass/storage/yaml_git/__init__.py b/reclass/storage/yaml_git/__init__.py
index d568d18..51a9713 100644
--- a/reclass/storage/yaml_git/__init__.py
+++ b/reclass/storage/yaml_git/__init__.py
@@ -4,8 +4,10 @@
# This file is part of reclass
import collections
+import distutils.version
import fnmatch
import os
+import paramiko
# Squelch warning on centos7 due to upgrading cffi
# see https://github.com/saltstack/salt/pull/39871
@@ -54,14 +56,69 @@
class GitRepo(object):
- def __init__(self, name):
- self.name = name
- if self.name.startswith('file://'):
- self.name = self.name[7:]
- self.repo = pygit2.Repository(self.name)
+ def __init__(self, url):
+ self.transport, _, self.url = url.partition('://')
+ self.name = self.url.replace('/', '_')
+ self.credentials = None
+ self.remotecallbacks = None
+ self._init_repo()
+ self._fetch()
self.branches = self.repo.listall_branches()
self.files = self.files_in_repo()
+ def _init_repo(self):
+ self.cache_dir = '{0}/{1}/{2}'.format(os.path.expanduser("~"), '.reclass/cache/git', self.name)
+ if os.path.exists(self.cache_dir):
+ self.repo = pygit2.Repository(self.cache_dir)
+ else:
+ os.makedirs(self.cache_dir)
+ self.repo = pygit2.init_repository(self.cache_dir, bare=True)
+
+ if not self.repo.remotes:
+ self.repo.create_remote('origin', self.url)
+
+ if 'ssh' in self.transport:
+ if '@' in self.url:
+ user, _, _ = self.url.partition('@')
+ else:
+ user = 'gitlab'
+ pygit2_version = pygit2.__version__
+ if distutils.version.LooseVersion(pygit2_version) >= distutils.version.LooseVersion('0.23.2'):
+ self.remotecallbacks = pygit2.RemoteCallbacks(credentials=pygit2.KeypairFromAgent(user))
+ self.credentials = None
+ else:
+ self.remotecallbacks = None
+ self.credentials = pygit2.KeypairFromAgent(user)
+
+ def _fetch(self):
+ origin = self.repo.remotes[0]
+ fetch_kwargs = {}
+ if self.remotecallbacks is not None:
+ fetch_kwargs['callbacks'] = self.remotecallbacks
+ if self.credentials is not None:
+ origin.credentials = self.credentials
+ fetch_results = origin.fetch(**fetch_kwargs)
+
+ remote_branches = self.repo.listall_branches(pygit2.GIT_BRANCH_REMOTE)
+ local_branches = self.repo.listall_branches()
+ for remote_branch_name in remote_branches:
+ _, _, local_branch_name = remote_branch_name.partition('/')
+ remote_branch = self.repo.lookup_branch(remote_branch_name, pygit2.GIT_BRANCH_REMOTE)
+ if local_branch_name not in local_branches:
+ local_branch = self.repo.create_branch(local_branch_name, self.repo[remote_branch.target.hex])
+ local_branch.upstream = remote_branch
+ else:
+ local_branch = self.repo.lookup_branch(local_branch_name)
+ if local_branch.target != remote_branch.target:
+ local_branch.set_target(remote_branch.target)
+
+ local_branches = self.repo.listall_branches()
+ for local_branch_name in local_branches:
+ remote_branch_name = '{0}/{1}'.format(origin.name, local_branch_name)
+ if remote_branch_name not in remote_branches:
+ local_branch = self.repo.lookup_branch(local_branch_name)
+ local.branch.delete()
+
def get(self, id):
return self.repo.get(id)