From e170e8454dfdc43f47e414f7b6ddc5733455c9fc Mon Sep 17 00:00:00 2001
From: Christopher Rhodes <christopher.rhodes@embl.de>
Date: Fri, 29 Nov 2024 16:37:31 +0100
Subject: [PATCH] Split out local and remote path roots, but successfully
 completed a test run on remote

---
 model_server/clients/batch_runner.py | 97 ++++++++++++++++++----------
 1 file changed, 63 insertions(+), 34 deletions(-)

diff --git a/model_server/clients/batch_runner.py b/model_server/clients/batch_runner.py
index 3d1d9ae6..d192cdd0 100644
--- a/model_server/clients/batch_runner.py
+++ b/model_server/clients/batch_runner.py
@@ -1,7 +1,6 @@
 from collections import OrderedDict
 import json
 from pathlib import Path
-import shutil
 
 import pandas as pd
 
@@ -21,34 +20,27 @@ class FileBatchRunnerClient(HttpClient):
         with open(conf_json, 'r') as fh:
             self.conf = json.load(fh)
 
-        self.paths = {
-            'conf': Path(self.conf['paths']['conf']),
-            'input': Path(self.conf['paths']['input']),
-            'output': Path(self.conf['paths']['output']),
-        }
-        for pa in self.paths.values():
-            pa.mkdir(parents=True, exist_ok=True)
+        self.local_paths = {k: Path(v) for k, v in self.conf['paths']['local'].items()}
+        self.remote_paths = {k: Path(v).as_posix() for k, v in self.conf['paths']['remote'].items()}
 
-        shutil.copy(conf_json, self.paths['output'] / conf_json.name)
+        for pa in self.local_paths.values():
+            pa.mkdir(parents=True, exist_ok=True)
 
         self.stacks = self.get_stacks(max_count=max_count)
-
         self.tasks = {}
-
         self.write_df()
 
         if not self.stacks['exists'].all():
-            raise FileNotFoundError(f'Found non-existent files, described in {self.pa_csv}')
+            raise FileNotFoundError(f'Trying to process non-existent image files')
         return super().__init__(**kwargs)
 
     def message(self, message):
         print(message)
 
     def write_df(self):
-        for rt in [self.conf_root, self.paths['output']]:
-            pa = rt / 'filelist.csv'
-            self.stacks.to_csv(pa)
-            self.message(f'Wrote stacks table to {pa}.')
+        pa = self.conf_root / 'filelist.csv'
+        self.stacks.to_csv(pa)
+        self.message(f'Wrote stacks table to {pa}.')
 
     def hit(self, method, endpoint, params=None, body=None, catch=True, **kwargs):
         resp = super(FileBatchRunnerClient, self).hit(method, endpoint, params=params, body=body)
@@ -70,17 +62,18 @@ class FileBatchRunnerClient(HttpClient):
             raise(e)
         self.message('Verified server is online at: ' + self.uri)
 
-    def watch_path(self, key, path, make=True, verify=False):
+    def watch_path(self, key, remote_path, local_path, make=True, verify=False, catch=False):
         if make:
-            path.mkdir(parents=True, exist_ok=True)
+            local_path.mkdir(parents=True, exist_ok=True)
 
         touch_uuid = self.hit(
             'put',
             f'/paths/watch_{key}',
-            params={'path': path.__str__(), 'touch': verify}
+            params={'path': remote_path.__str__(), 'touch': verify},
+            catch=catch
         )
         if verify:
-            pa_touch = path / 'svlt.touch'
+            pa_touch = local_path / 'svlt.touch'
             try:
                 with open(pa_touch, 'r') as fh:
                     cont = fh.read()
@@ -88,12 +81,33 @@ class FileBatchRunnerClient(HttpClient):
                 pa_touch.unlink()
             except Exception as e:
                 raise WatchPathVerificationError(e)
-        self.message(f'Watching {path} for {key} data')
-
-    def setup(self):
-        self.watch_path('conf', self.paths['conf'], verify=True, make=False)
-        self.watch_path('output', self.paths['output'], verify=True, make=True)
-        self.watch_path('input', self.paths['input'], verify=False, make=False)
+        self.message(f'Watching {remote_path} (remote), {local_path} (local) for {key} data')
+
+    def setup(self, catch=True,):
+        self.watch_path(
+            'conf',
+            self.remote_paths['conf'],
+            self.local_paths['conf'],
+            verify=False,
+            make=False,
+            catch=catch,
+        )
+        self.watch_path(
+            'output',
+            self.remote_paths['output'],
+            self.local_paths['output'],
+            verify=False,
+            make=True,
+            catch=catch,
+        )
+        self.watch_path(
+            'input',
+            self.remote_paths['input'],
+            self.local_paths['input'],
+            verify=False,
+            make=False,
+            catch=catch,
+        )
 
         for v in self.conf['setup']:
             resp = self.hit(**v, catch=False)
@@ -102,7 +116,7 @@ class FileBatchRunnerClient(HttpClient):
     def get_stacks(self, max_count=None):
         paths = []
         for inp in self.conf['inputs']:
-            loc = Path(self.paths['input']) / inp['directory']
+            where_local = Path(self.local_paths['input']) / inp['directory']
 
             # get explicit filenames
             files = inp.get('files', [])
@@ -111,20 +125,28 @@ class FileBatchRunnerClient(HttpClient):
             if pattern := inp.get('pattern'):
                 if pattern == '':
                     break
-                for f in list(loc.iterdir()):
+                for f in list(where_local.iterdir()):
                     if pattern.upper() in f.name.upper() and f.name not in files:
                         files.append(f.name)
             is_multiposition = inp.get('multiposition', False)
-            paths = paths + [{'path': loc / f, 'is_multiposition': is_multiposition} for f in files]
+            where_remote = Path(self.remote_paths['input']) / inp['directory']
+
+            def _get_file_info(filename):
+                return {
+                    'remote_path': (where_remote / filename).as_posix(),
+                    'local_path': where_local / filename,
+                    'is_multiposition': is_multiposition,
+                }
+            paths = paths + [_get_file_info(f) for f in files]
         if max_count is not None:
             df = pd.DataFrame(paths).head(min(max_count, len(paths)))
         else:
             df = pd.DataFrame(paths)
         if len(df) == 0:
             raise EmptyFileListError('No files were found')
-        df['exists'] = df['path'].apply(lambda x: x.exists())
-        df['parent'] = df['path'].apply(lambda x: x.parent)
-        df['filename'] = df['path'].apply(lambda x: x.name)
+        df['exists'] = df['local_path'].apply(lambda x: x.exists())
+        df['parent'] = df['local_path'].apply(lambda x: x.parent)
+        df['filename'] = df['local_path'].apply(lambda x: x.name)
         df['accessor_id'] = None
         self.message(f'Found {len(df)} input files.')
         return df
@@ -140,8 +162,15 @@ class FileBatchRunnerClient(HttpClient):
                 return None
 
         accessor_ids = []
-        for pa_dir, df_gb in self.stacks.groupby('parent'):
-            self.watch_path('input', self.paths['input'] / pa_dir, verify=False, make=False)
+        for loc_pa, df_gb in self.stacks.groupby('parent'):
+            pa_dir = loc_pa.relative_to(self.local_paths['input']).as_posix()
+            self.watch_path(
+                'input',
+                (Path(self.remote_paths['input']) / pa_dir).as_posix(),
+                self.local_paths['input'] / pa_dir,
+                verify=False,
+                make=False
+            )
             df_gb['accessor_id'] = df_gb.apply(_read, axis=1)
             df_gb['position'] = df_gb['accessor_id'].apply(lambda x: range(0, len(x)))
             df_gb = df_gb.explode(['accessor_id', 'position'])
-- 
GitLab