Systems and methods for flexible parameter sharing for multi-task learning are provided. A training method can include obtaining a test input, selecting a particular task from one or more tasks, and training a multi-task machine-learned model for the particular task by performing a forward pass using the test input and one or more connection probability matrices to generate a sample distribution of test outputs, training the components of the machine-learned model based at least in part on the sample distribution, and performing a backwards pass to train a connection probability matrix of the multi-task machine-learned model using a straight-through Gumbel-softmax approximation.
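The training procedure in the abstract can be illustrated with a minimal sketch of the straight-through Gumbel-softmax idea for a single task and a single layer. The sketch rests on several assumptions that are not from the patent. Each connection probability is parameterized by a logit `theta` with p = sigmoid(theta). The components are stand-ins with fixed output vectors, in place of components applied to a test input. Routed outputs are aggregated by summation, and a squared-error loss is used. Training of the components themselves is omitted. The names `st_gumbel_binary` and `train_task`, the temperature `tau`, and the learning rate are all hypothetical. In the forward pass a hard 0/1 connection is sampled, while the backward pass uses the gradient of the relaxed (soft) sample in its place.

```python
import math
import random


def sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def sample_gumbel(rng):
    # Standard Gumbel(0, 1) noise by inverse CDF: -log(-log(U)), U ~ Uniform(0, 1).
    u = rng.random()
    while u == 0.0:  # random() may return 0.0, and log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))


def st_gumbel_binary(theta, tau, rng):
    """Straight-through Gumbel-softmax sample of one binary connection (sketch).

    theta is the logit of p = sigmoid(theta), so the complementary logits
    log(p) and log(1 - p) differ by exactly theta. Returns (hard, grad): the
    forward value is a hard 0/1 sample; grad is d(soft)/d(theta), the gradient
    of the relaxed sample, used in place of the (zero) gradient of hard.
    """
    z = (theta + sample_gumbel(rng) - sample_gumbel(rng)) / tau
    soft = sigmoid(z)                # "on" entry of a two-way softmax at temperature tau
    hard = 1.0 if z > 0.0 else 0.0   # same decision as argmax of the noisy logits
    return hard, soft * (1.0 - soft) / tau


def train_task(thetas, outputs, target, steps=200, lr=0.5, tau=1.0, seed=0):
    """Hypothetical toy: learn connection logits for one task in one layer.

    outputs[k] is component k's fixed output vector; activated outputs are
    summed and scored with squared error against target.
    """
    rng = random.Random(seed)
    thetas = list(thetas)
    for _ in range(steps):
        # Forward pass: sample a hard routing vector for this task.
        samples = [st_gumbel_binary(t, tau, rng) for t in thetas]
        y = [sum(r * o[j] for (r, _), o in zip(samples, outputs))
             for j in range(len(target))]
        # dL/dy for L = sum_j (y_j - target_j)^2.
        dy = [2.0 * (yj - tj) for yj, tj in zip(y, target)]
        # Backward pass: straight-through gradient to every logit, including
        # components that were not activated in this sample (hard == 0).
        for k, ((_, g), o) in enumerate(zip(samples, outputs)):
            dr = sum(dj * oj for dj, oj in zip(dy, o))
            thetas[k] -= lr * dr * g
    return thetas


if __name__ == "__main__":
    # Three orthogonal toy components; the target calls for components 0 and 2.
    learned = train_task([0.0, 0.0, 0.0],
                         [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
                         [1.0, 0.0, 1.0])
    print([round(sigmoid(t), 3) for t in learned])
```

Because the surrogate gradient is nonzero even when a component was not sampled, the logit of an inactive component can still change. This loosely mirrors the update recited in claim 20.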
Claims (as filed with the USPTO)
2. The computer-implemented method of claim 1, wherein the approximation comprises a straight-through Gumbel-Softmax approximation.
3. The computer-implemented method of claim 1, wherein the respective probability value is obtained from a connection probability matrix, wherein the connection probability matrix indicates, for the particular task, one or more connection probabilities respectively for the one or more components.
4. The computer-implemented method of claim 3, wherein, in a forward pass, the routing matrix is generated for the layer by sampling a respective binary value for a respective position in the routing matrix using a corresponding respective probability value from the respective position in the connection probability matrix.
5. The computer-implemented method of claim 4, wherein sampling the respective binary value comprises adding independent noise from a Gumbel distribution to each of a pair of complementary logits and selecting the respective binary value with the highest logit.
8. The computer-implemented method of claim 3, wherein an initial value for the connection probability matrix is selected to encourage or discourage a particular routing pathway.
9. The computer-implemented method of claim 1, wherein the routing matrix comprises one or more binary values respectively associated with a pairing of the one or more components and the particular task.
10. The computer-implemented method of claim 1, wherein outputs of the one or more components are aggregated before passing downstream.
17. The computing system of claim 16, wherein the routing matrix comprises binary values that activate connections to the one or more components.
18. The computing system of claim 17, wherein a respective value at a respective location of the routing matrix was obtained by selecting the maximum likelihood binary value associated with a corresponding value of the connection probability matrix.
19. The computing system of claim 17, wherein the binary values are indexed based on the particular task.
20. The computing system of claim 16, wherein learning the connection probability matrix comprises updating the connection probability matrix based on components of the layer not activated by the routing matrix.
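Claims 4, 5, 10, 12, and 18 can be illustrated together with a minimal sketch under stated assumptions. The connection probability matrix is taken to be a list of rows indexed by task, with one column per component. Activated component outputs are aggregated by summation. "Maximum likelihood binary value" is read as choosing 1 when p > 0.5. All function names are hypothetical; the claims do not specify an implementation. Adding Gumbel noise to the complementary logits log(p) and log(1 - p) and keeping the larger one is the Gumbel-max trick, so each position is 1 with probability p.

```python
import math
import random


def sample_gumbel(rng):
    # Standard Gumbel(0, 1) noise by inverse CDF: -log(-log(U)), U ~ Uniform(0, 1).
    u = rng.random()
    while u == 0.0:  # random() may return 0.0, and log(0) is undefined
        u = rng.random()
    return -math.log(-math.log(u))


def sample_routing_matrix(conn_probs, rng):
    """Sample a binary routing matrix from a connection probability matrix.

    conn_probs[t][k] is the probability that task t connects to component k.
    Independent Gumbel noise is added to log(p) and log(1 - p), and the binary
    value with the larger noisy logit is kept, so P(value == 1) == p.
    """
    routing = []
    for row in conn_probs:
        out_row = []
        for p in row:
            if not 0.0 < p < 1.0:
                raise ValueError("connection probabilities must lie in (0, 1)")
            on = math.log(p) + sample_gumbel(rng)
            off = math.log(1.0 - p) + sample_gumbel(rng)
            out_row.append(1 if on > off else 0)
        routing.append(out_row)
    return routing


def max_likelihood_routing(conn_probs):
    # Deterministic routing: keep the more likely binary value at each position.
    return [[1 if p > 0.5 else 0 for p in row] for row in conn_probs]


def route(routing, task, component_outputs):
    # Index the routing row by task, then sum the outputs of activated components.
    row = routing[task]
    return sum(r * out for r, out in zip(row, component_outputs))


if __name__ == "__main__":
    probs = [[0.9, 0.2, 0.5], [0.1, 0.7, 0.95]]  # 2 tasks x 3 components (toy values)
    sampled = sample_routing_matrix(probs, random.Random(0))
    print(sampled, max_likelihood_routing(probs))
    print(route(max_likelihood_routing(probs), 1, [1.0, 2.0, 4.0]))
```

Sampling is a natural fit for training, where different pathways are explored, while the deterministic maximum-likelihood matrix gives a fixed routing once the probabilities have been learned.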
March 17, 2020
February 27, 2024