]> git.notmuchmail.org Git - notmuch/blobdiff - notmuch-git.py
Merge branch 'release'
[notmuch] / notmuch-git.py
index a75de135a8378869b3b1a6f833df941c36530cf2..ee87bec611c632cec23875ae7386d8e641081536 100644 (file)
@@ -31,7 +31,6 @@ import locale as _locale
 import logging as _logging
 import os as _os
 import re as _re
-import shutil as _shutil
 import subprocess as _subprocess
 import sys as _sys
 import tempfile as _tempfile
@@ -248,13 +247,13 @@ def count_messages(prefix=None):
         stdout=_subprocess.PIPE, wait=True)
     if status != 0:
         _LOG.error("failed to run notmuch config")
-        sys.exit(1)
+        _sys.exit(1)
     return int(stdout.rstrip())
 
 def get_tags(prefix=None):
     "Get a list of tags with a given prefix."
     (status, stdout, stderr) = _spawn(
-        args=['notmuch', 'search', '--query=sexp', '--output=tags', _tag_query(prefix)],
+        args=['notmuch', 'search', '--exclude=false', '--query=sexp', '--output=tags', _tag_query(prefix)],
         stdout=_subprocess.PIPE, wait=True)
     return [tag for tag in stdout.splitlines()]
 
@@ -369,7 +368,7 @@ class CachedIndex:
         _git(args=['read-tree', self.current_treeish], wait=True)
 
 
-def check_safe_fraction(status):
+def _check_fraction(change):
     safe = 0.1
     conf = _notmuch_config_get ('git.safe_fraction')
     if conf and conf != '':
@@ -377,10 +376,10 @@ def check_safe_fraction(status):
 
     total = count_messages (TAG_PREFIX)
     if total == 0:
-        _LOG.error('No existing tags with given prefix, stopping.'.format(safe))
+        _LOG.error('No existing tags with given prefix, stopping.')
         _LOG.error('Use --force to override.')
         exit(1)
-    change = len(status['added'])+len(status['deleted'])
+
     fraction = change/total
     _LOG.debug('total messages {:d}, change: {:d}, fraction: {:f}'.format(total,change,fraction))
     if fraction > safe:
@@ -388,6 +387,25 @@ def check_safe_fraction(status):
         _LOG.error('Use --force to override or reconfigure git.safe_fraction.')
         exit(1)
 
+def check_safe_fraction(status):
+
+    change = len(status['added'])+len(status['deleted'])
+    _check_fraction(change)
+
+def check_diff_fraction():
+
+    # check number of directories (i.e. messages) changed.
+    change_set = set()
+
+    with _git(args=['diff', '--name-only', 'HEAD', '@{upstream}'],
+              stdout=_subprocess.PIPE) as git:
+        for path in git.stdout:
+            change_set.add(_os.path.dirname(path))
+
+    change=len(change_set)
+    _check_fraction(change)
+
+
 def commit(treeish='HEAD', message=None, force=False):
     """
     Commit prefix-matching tags from the notmuch database to Git.
@@ -620,6 +638,15 @@ def push(repository=None, refspecs=None):
     _git(args=args, wait=True)
 
 
+def reset(force=False):
+    """
+    reset the local git branch to match the remote one
+    """
+    if not force:
+        check_diff_fraction()
+
+    _git(args=["reset","--soft","origin/master"],wait=True)
+
 def status():
     """
     Show pending updates in notmuch or git repo.
@@ -698,6 +725,32 @@ def _is_unmerged(ref='@{upstream}'):
         stdout=_subprocess.PIPE, wait=True)
     return base != fetch_head
 
+class DatabaseCache:
+    def __init__(self):
+        try:
+            from notmuch2 import Database
+            self._notmuch = Database()
+        except ImportError:
+            self._notmuch = None
+        self._known = {}
+
+    def known(self,id):
+        if id in self._known:
+            return self._known[id];
+
+        if self._notmuch:
+            try:
+                _ = self._notmuch.find(id)
+                self._known[id] = True
+            except LookupError:
+                self._known[id] = False
+        else:
+            (_, stdout, stderr) = _spawn(
+                args=['notmuch', 'search', '--exclude=false', '--output=files', 'id:{0}'.format(id)],
+                stdout=_subprocess.PIPE,
+                wait=True)
+            self._known[id] = stdout != None
+        return self._known[id]
 
 @timed
 def get_status():
@@ -705,14 +758,11 @@ def get_status():
         'deleted': {},
         'missing': {},
         }
+    db = DatabaseCache()
     with PrivateIndex(repo=NOTMUCH_GIT_DIR, prefix=TAG_PREFIX) as index:
         maybe_deleted = index.diff(filter='D')
         for id, tags in maybe_deleted.items():
-            (_, stdout, stderr) = _spawn(
-                args=['notmuch', 'search', '--output=files', 'id:{0}'.format(id)],
-                stdout=_subprocess.PIPE,
-                wait=True)
-            if stdout:
+            if db.known(id):
                 status['deleted'][id] = tags
             else:
                 status['missing'][id] = tags
@@ -738,6 +788,7 @@ class PrivateIndex:
         self.lastmod = None
         self.checksum = None
         self._load_cache_file()
+        self.file_tree = None
         self._index_tags()
 
     def __enter__(self):
@@ -763,6 +814,43 @@ class PrivateIndex:
             _LOG.error("Error decoding cache")
             _sys.exit(1)
 
+    @timed
+    def _read_file_tree(self):
+        self.file_tree = {}
+
+        with _git(
+                args=['ls-files', 'tags'],
+                additional_env={'GIT_INDEX_FILE': self.index_path},
+                stdout=_subprocess.PIPE) as git:
+            for file in git.stdout:
+                dir=_os.path.dirname(file)
+                tag=_os.path.basename(file).rstrip()
+                if dir not in self.file_tree:
+                    self.file_tree[dir]=[tag]
+                else:
+                    self.file_tree[dir].append(tag)
+
+
+    def _clear_tags_for_message(self, id):
+        """
+        Clear any existing index entries for message 'id'
+
+        Neither 'id' nor the tags in 'tags' should be encoded/escaped.
+        """
+
+        if self.file_tree == None:
+            self._read_file_tree()
+
+        dir = _id_path(id)
+
+        if dir not in self.file_tree:
+            return
+
+        for file in self.file_tree[dir]:
+            line = '0 0000000000000000000000000000000000000000\t{:s}/{:s}\n'.format(dir,file)
+            yield line
+
+
     @timed
     def _index_tags(self):
         "Write notmuch tags to private git index."
@@ -798,7 +886,7 @@ class PrivateIndex:
                         if tag.startswith(prefix)]
                     id = _xapian_unquote(string=id)
                     if clear_tags:
-                        for line in _clear_tags_for_message(index=self.index_path, id=id):
+                        for line in self._clear_tags_for_message(id=id):
                             git.stdin.write(line)
                     for line in _index_tags_for_message(
                             id=id, status='A', tags=tags):
@@ -835,24 +923,6 @@ def _read_index_checksum (index_path):
     except FileNotFoundError:
         return None
 
-
-def _clear_tags_for_message(index, id):
-    """
-    Clear any existing index entries for message 'id'
-
-    Neither 'id' nor the tags in 'tags' should be encoded/escaped.
-    """
-
-    dir = _id_path(id)
-
-    with _git(
-            args=['ls-files', dir],
-            additional_env={'GIT_INDEX_FILE': index},
-            stdout=_subprocess.PIPE) as git:
-        for file in git.stdout:
-            line = '0 0000000000000000000000000000000000000000\t{:s}\n'.format(file.strip())
-            yield line
-
 def _read_database_lastmod():
     with _spawn(
             args=['notmuch', 'count', '--lastmod', '*'],
@@ -1005,6 +1075,7 @@ if __name__ == '__main__':
             'merge',
             'pull',
             'push',
+            'reset',
             'status',
             ]:
         func = locals()[command]
@@ -1099,6 +1170,10 @@ if __name__ == '__main__':
                     'Refspec (usually a branch name) to push.  See '
                     'the <refspec> entry in the OPTIONS section of '
                     'git-push(1) for other possibilities.'))
+        elif command == 'reset':
+            subparser.add_argument(
+                '-f', '--force', action='store_true',
+                help='reset a large fraction of tags.')
 
     args = parser.parse_args()