Usermode/ld-acess - Fix dynamic linking quirk (STB_WEAK and R_COPY)
[tpg/acess2.git] / Usermode / Libraries / ld-acess.so_src / loadlib.c
index dcacd8a..1c750a8 100644 (file)
@@ -4,6 +4,7 @@
 */
 #include "common.h"
 #include <stdint.h>
+#include <stdbool.h>
 #include <acess/sys.h>
 
 #define DEBUG  0
@@ -14,8 +15,7 @@
 # define DEBUGS(v...)  
 #endif
 
-// === PROTOTYPES ===
-void   *IsFileLoaded(const char *file);
+#define MAX_QUEUED_ENTRYPOINTS 8
 
 // === IMPORTS ===
 extern const struct {
@@ -24,14 +24,52 @@ extern const struct {
 }      caLocalExports[];
 extern const int       ciNumLocalExports;
 extern char    **gEnvP;
+extern char    gLinkedBase[];
+
+// === TYPES ===
+typedef void   tLibEntry(void *, int, char *[], char**);
+
+// === PROTOTYPES ===
+void   *IsFileLoaded(const char *file);
 
 // === GLOABLS ===
 tLoadedLib     gLoadedLibraries[MAX_LOADED_LIBRARIES];
 char   gsLoadedStrings[MAX_STRINGS_BYTES];
 char   *gsNextAvailString = gsLoadedStrings;
+struct sQueuedEntry {
+       void    *Base;
+       tLibEntry       *Entry;
+}      gaQueuedEntrypoints[MAX_QUEUED_ENTRYPOINTS];
+ int   giNumQueuedEntrypoints;
 //tLoadLib     *gpLoadedLibraries = NULL;
 
 // === CODE ===
+void ldacess_DumpLoadedLibraries(void)
+{
+       for( int i = 0; i < MAX_LOADED_LIBRARIES; i ++ )
+       {
+               if(gLoadedLibraries[i].Base == 0)       break;  // Last entry has Base set to NULL
+               _SysDebug("%p: %s",
+                       gLoadedLibraries[i].Base,
+                       gLoadedLibraries[i].Name
+                       );
+       }
+}
+
+/**
+ * \brief Call queued up entry points (after relocations completed) 
+ */
+void CallQueuedEntrypoints(char **EnvP)
+{
+       while( giNumQueuedEntrypoints )
+       {
+               giNumQueuedEntrypoints --;
+               const struct sQueuedEntry       *qe = &gaQueuedEntrypoints[giNumQueuedEntrypoints];
+               //_SysDebug("Calling EP for %p", qe->Base);
+               qe->Entry(qe->Base, 0, NULL, EnvP);
+       }
+}
+
 const char *FindLibrary(char *DestBuf, const char *SoName, const char *ExtraSearchDir)
 {      
        // -- #1: Executable Specified
@@ -59,14 +97,12 @@ const char *FindLibrary(char *DestBuf, const char *SoName, const char *ExtraSear
 void *LoadLibrary(const char *SoName, const char *SearchDir, char **envp)
 {
        char    sTmpName[1024];
-       const char      *filename;
        void    *base;
-       void    (*fEntry)(void *, int, char *[], char**);
        
        DEBUGS("LoadLibrary: (SoName='%s', SearchDir='%s', envp=%p)", SoName, SearchDir, envp);
        
        // Create Temp Name
-       filename = FindLibrary(sTmpName, SoName, SearchDir);
+       const char *filename = FindLibrary(sTmpName, SoName, SearchDir);
        if(filename == NULL) {
                DEBUGS("LoadLibrary: RETURN 0");
                return 0;
@@ -78,6 +114,7 @@ void *LoadLibrary(const char *SoName, const char *SearchDir, char **envp)
 
        DEBUGS(" LoadLibrary: SysLoadBin()");   
        // Load Library
+       tLibEntry       *fEntry;
        base = _SysLoadBin(filename, (void**)&fEntry);
        if(!base) {
                DEBUGS("LoadLibrary: RETURN 0");
@@ -93,10 +130,17 @@ void *LoadLibrary(const char *SoName, const char *SearchDir, char **envp)
        }
        
        // Call Entrypoint
-       DEBUGS(" LoadLibrary: '%s' Entry %p", SoName, fEntry);
-       fEntry(base, 0, NULL, gEnvP);
+       // - TODO: Queue entrypoint calls
+       if( giNumQueuedEntrypoints >= MAX_QUEUED_ENTRYPOINTS ) {
+               SysDebug("ERROR - Maximum number of queued entrypoints exceeded on %p '%s'",
+                       base, SoName);
+               return 0;
+       }
+       gaQueuedEntrypoints[giNumQueuedEntrypoints].Base  = base;
+       gaQueuedEntrypoints[giNumQueuedEntrypoints].Entry = fEntry;
+       giNumQueuedEntrypoints ++;
        
-       DEBUGS("LoadLibrary: RETURN 1");
+       DEBUGS("LoadLibrary: RETURN success");
        return base;
 }
 
@@ -108,6 +152,15 @@ void *IsFileLoaded(const char *file)
 {
         int    i;
        DEBUGS("IsFileLoaded: (file='%s')", file);
+
+       // Applications link against either libld-acess.so or ld-acess.so
+       if( strcmp(file, "/Acess/Libs/libld-acess.so") == 0
+        || strcmp(file, "/Acess/Libs/ld-acess.so") == 0 )
+       {
+               DEBUGS("IsFileLoaded: Found local (%p)", &gLinkedBase);
+               return &gLinkedBase;
+       }
+
        for( i = 0; i < MAX_LOADED_LIBRARIES; i++ )
        {
                if(gLoadedLibraries[i].Base == 0)       break;  // Last entry has Base set to NULL
@@ -201,34 +254,50 @@ void Unload(void *Base)
  \fn Uint GetSymbol(const char *name)
  \brief Gets a symbol value from a loaded library
 */
-int GetSymbol(const char *name, void **Value, size_t *Size)
+int GetSymbol(const char *name, void **Value, size_t *Size, void *IgnoreBase)
 {
-        int    i;
-       
-       //SysDebug("ciNumLocalExports = %i", ciNumLocalExports);
-       for(i=0;i<ciNumLocalExports;i++)
+       //SysDebug("GetSymbol: (%s)");
+       for( int i = 0; i < ciNumLocalExports; i ++ )
        {
                if( strcmp(caLocalExports[i].Name, name) == 0 ) {
                        *Value = caLocalExports[i].Value;
                        if(Size)
                                *Size = 0;
+                       //SysDebug("GetSymbol: Local %p+0x%x", *Value, 0);
                        return 1;
                }
        }
-       
-       // Entry 0 is ld-acess, ignore it
-       for(i = 1; i < MAX_LOADED_LIBRARIES; i ++)
+
+       bool have_weak = false; 
+       for(int i = 0; i < MAX_LOADED_LIBRARIES && gLoadedLibraries[i].Base != 0; i ++)
        {
-               if(gLoadedLibraries[i].Base == 0)
-                       break;
+               // Allow ignoring the current module
+               if( gLoadedLibraries[i].Base == IgnoreBase ) {
+                       //SysDebug("GetSymbol: Ignore %p", gLoadedLibraries[i].Base);
+                       continue ;
+               }
                
                //SysDebug(" GetSymbol: Trying 0x%x, '%s'",
                //      gLoadedLibraries[i].Base, gLoadedLibraries[i].Name);
-               if(GetSymbolFromBase(gLoadedLibraries[i].Base, name, Value, Size))
-                       return 1;
+               void    *tmpval;
+               size_t  tmpsize;
+               int rv = GetSymbolFromBase(gLoadedLibraries[i].Base, name, &tmpval, &tmpsize);
+               if(rv)
+               {
+                       *Value = tmpval;
+                       *Size = tmpsize;
+                       if( rv == 1 ) {
+                               return 1;
+                       }
+                       have_weak = true;
+               }
+       }
+       if(have_weak) {
+               return 2;
+       }
+       else {
+               return 0;
        }
-       SysDebug("GetSymbol: === Symbol '%s' not found ===", name);
-       return 0;
 }
 
 /**

UCC git Repository :: git.ucc.asn.au